[LIBXGBOOST] pass demo running.
This commit is contained in:
528
src/c_api/c_api.cc
Normal file
528
src/c_api/c_api.cc
Normal file
@@ -0,0 +1,528 @@
|
||||
// Copyright (c) 2014 by Contributors
|
||||
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/learner.h>
|
||||
#include <xgboost/c_api.h>
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
|
||||
#include "./c_api_error.h"
|
||||
#include "../data/simple_csr_source.h"
|
||||
#include "../common/thread_local.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/group_data.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
// booster wrapper for backward compatible reason.
|
||||
class Booster {
|
||||
public:
|
||||
explicit Booster(const std::vector<DMatrix*>& cache_mats)
|
||||
: configured_(false),
|
||||
initialized_(false),
|
||||
learner_(Learner::Create(cache_mats)) {}
|
||||
|
||||
inline Learner* learner() {
|
||||
return learner_.get();
|
||||
}
|
||||
|
||||
inline void SetParam(const std::string& name, const std::string& val) {
|
||||
cfg_.push_back(std::make_pair(name, val));
|
||||
if (configured_) {
|
||||
learner_->Configure(cfg_);
|
||||
}
|
||||
}
|
||||
|
||||
inline void LazyInit() {
|
||||
if (!configured_) {
|
||||
learner_->Configure(cfg_);
|
||||
configured_ = true;
|
||||
}
|
||||
if (!initialized_) {
|
||||
learner_->InitModel();
|
||||
initialized_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
inline void LoadModel(dmlc::Stream* fi) {
|
||||
learner_->Load(fi);
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
public:
|
||||
bool configured_;
|
||||
bool initialized_;
|
||||
std::unique_ptr<Learner> learner_;
|
||||
std::vector<std::pair<std::string, std::string> > cfg_;
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
using namespace xgboost; // NOLINT(*);
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
struct XGBAPIThreadLocalEntry {
|
||||
/*! \brief result holder for returning string */
|
||||
std::string ret_str;
|
||||
/*! \brief result holder for returning strings */
|
||||
std::vector<std::string> ret_vec_str;
|
||||
/*! \brief result holder for returning string pointers */
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
/*! \brief returning float vector. */
|
||||
std::vector<float> ret_vec_float;
|
||||
/*! \brief temp variable of gradient pairs. */
|
||||
std::vector<bst_gpair> tmp_gpair;
|
||||
};
|
||||
|
||||
// define the threadlocal store.
|
||||
typedef xgboost::common::ThreadLocalStore<XGBAPIThreadLocalEntry> XGBAPIThreadLocalStore;
|
||||
|
||||
int XGDMatrixCreateFromFile(const char *fname,
|
||||
int silent,
|
||||
DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
*out = DMatrix::Load(
|
||||
fname, silent != 0, false);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixCreateFromCSR(const bst_ulong* indptr,
|
||||
const unsigned *indices,
|
||||
const float* data,
|
||||
bst_ulong nindptr,
|
||||
bst_ulong nelem,
|
||||
DMatrixHandle* out) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
data::SimpleCSRSource& mat = *source;
|
||||
mat.row_ptr_.resize(nindptr);
|
||||
for (bst_ulong i = 0; i < nindptr; ++i) {
|
||||
mat.row_ptr_[i] = static_cast<size_t>(indptr[i]);
|
||||
}
|
||||
mat.row_data_.resize(nelem);
|
||||
for (bst_ulong i = 0; i < nelem; ++i) {
|
||||
mat.row_data_[i] = RowBatch::Entry(indices[i], data[i]);
|
||||
mat.info.num_col = std::max(mat.info.num_col,
|
||||
static_cast<size_t>(indices[i] + 1));
|
||||
}
|
||||
mat.info.num_row = nindptr - 1;
|
||||
mat.info.num_nonzero = static_cast<uint64_t>(nelem);
|
||||
*out = DMatrix::Create(std::move(source));
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixCreateFromCSC(const bst_ulong* col_ptr,
|
||||
const unsigned* indices,
|
||||
const float* data,
|
||||
bst_ulong nindptr,
|
||||
bst_ulong nelem,
|
||||
DMatrixHandle* out) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
nthread = omp_get_num_threads();
|
||||
}
|
||||
data::SimpleCSRSource& mat = *source;
|
||||
common::ParallelGroupBuilder<RowBatch::Entry> builder(&mat.row_ptr_, &mat.row_data_);
|
||||
builder.InitBudget(0, nthread);
|
||||
long ncol = static_cast<long>(nindptr - 1); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long i = 0; i < ncol; ++i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
for (unsigned j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
|
||||
builder.AddBudget(indices[j], tid);
|
||||
}
|
||||
}
|
||||
builder.InitStorage();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long i = 0; i < ncol; ++i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
for (unsigned j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
|
||||
builder.Push(indices[j],
|
||||
RowBatch::Entry(static_cast<bst_uint>(i), data[j]),
|
||||
tid);
|
||||
}
|
||||
}
|
||||
mat.info.num_row = mat.row_ptr_.size() - 1;
|
||||
mat.info.num_col = static_cast<uint64_t>(ncol);
|
||||
mat.info.num_nonzero = nelem;
|
||||
*out = DMatrix::Create(std::move(source));
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixCreateFromMat(const float* data,
|
||||
bst_ulong nrow,
|
||||
bst_ulong ncol,
|
||||
float missing,
|
||||
DMatrixHandle* out) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
data::SimpleCSRSource& mat = *source;
|
||||
bool nan_missing = common::CheckNAN(missing);
|
||||
mat.info.num_row = nrow;
|
||||
mat.info.num_col = ncol;
|
||||
for (bst_ulong i = 0; i < nrow; ++i, data += ncol) {
|
||||
bst_ulong nelem = 0;
|
||||
for (bst_ulong j = 0; j < ncol; ++j) {
|
||||
if (common::CheckNAN(data[j])) {
|
||||
CHECK(nan_missing)
|
||||
<< "There are NAN in the matrix, however, you did not set missing=NAN";
|
||||
} else {
|
||||
if (nan_missing || data[j] != missing) {
|
||||
mat.row_data_.push_back(RowBatch::Entry(j, data[j]));
|
||||
++nelem;
|
||||
}
|
||||
}
|
||||
}
|
||||
mat.row_ptr_.push_back(mat.row_ptr_.back() + nelem);
|
||||
}
|
||||
mat.info.num_nonzero = mat.row_data_.size();
|
||||
*out = DMatrix::Create(std::move(source));
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
const int* idxset,
|
||||
bst_ulong len,
|
||||
DMatrixHandle* out) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
data::SimpleCSRSource src;
|
||||
src.CopyFrom(static_cast<DMatrix*>(handle));
|
||||
data::SimpleCSRSource& ret = *source;
|
||||
|
||||
CHECK_EQ(src.info.group_ptr.size(), 0)
|
||||
<< "slice does not support group structure";
|
||||
|
||||
ret.Clear();
|
||||
ret.info.num_row = len;
|
||||
ret.info.num_col = src.info.num_col;
|
||||
|
||||
dmlc::DataIter<RowBatch>* iter = &src;
|
||||
iter->BeforeFirst();
|
||||
CHECK(iter->Next());
|
||||
|
||||
const RowBatch& batch = iter->Value();
|
||||
for (bst_ulong i = 0; i < len; ++i) {
|
||||
const int ridx = idxset[i];
|
||||
RowBatch::Inst inst = batch[ridx];
|
||||
CHECK_LT(static_cast<bst_ulong>(ridx), batch.size);
|
||||
ret.row_data_.resize(ret.row_data_.size() + inst.length);
|
||||
std::memcpy(dmlc::BeginPtr(ret.row_data_) + ret.row_ptr_.back(), inst.data,
|
||||
sizeof(RowBatch::Entry) * inst.length);
|
||||
ret.row_ptr_.push_back(ret.row_ptr_.back() + inst.length);
|
||||
ret.info.num_nonzero += inst.length;
|
||||
|
||||
if (src.info.labels.size() != 0) {
|
||||
ret.info.labels.push_back(src.info.labels[ridx]);
|
||||
}
|
||||
if (src.info.weights.size() != 0) {
|
||||
ret.info.weights.push_back(src.info.weights[ridx]);
|
||||
}
|
||||
if (src.info.root_index.size() != 0) {
|
||||
ret.info.root_index.push_back(src.info.root_index[ridx]);
|
||||
}
|
||||
}
|
||||
*out = DMatrix::Create(std::move(source));
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixFree(DMatrixHandle handle) {
|
||||
API_BEGIN();
|
||||
delete static_cast<DMatrix*>(handle);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixSaveBinary(DMatrixHandle handle,
|
||||
const char* fname,
|
||||
int silent) {
|
||||
API_BEGIN();
|
||||
static_cast<DMatrix*>(handle)->SaveToLocalFile(fname);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const float* info,
|
||||
bst_ulong len) {
|
||||
API_BEGIN();
|
||||
static_cast<DMatrix*>(handle)->info().SetInfo(field, info, kFloat32, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
const char* field,
|
||||
const unsigned* info,
|
||||
bst_ulong len) {
|
||||
API_BEGIN();
|
||||
static_cast<DMatrix*>(handle)->info().SetInfo(field, info, kUInt32, len);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
const unsigned* group,
|
||||
bst_ulong len) {
|
||||
API_BEGIN();
|
||||
DMatrix *pmat = static_cast<DMatrix*>(handle);
|
||||
MetaInfo& info = pmat->info();
|
||||
info.group_ptr.resize(len + 1);
|
||||
info.group_ptr[0] = 0;
|
||||
for (uint64_t i = 0; i < len; ++i) {
|
||||
info.group_ptr[i + 1] = info.group_ptr[i] + group[i];
|
||||
}
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
|
||||
const char* field,
|
||||
bst_ulong* out_len,
|
||||
const float** out_dptr) {
|
||||
API_BEGIN();
|
||||
const MetaInfo& info = static_cast<const DMatrix*>(handle)->info();
|
||||
const std::vector<float>* vec = nullptr;
|
||||
if (!std::strcmp(field, "label")) {
|
||||
vec = &info.labels;
|
||||
} else if (!std::strcmp(field, "weight")) {
|
||||
vec = &info.weights;
|
||||
} else if (!std::strcmp(field, "base_margin")) {
|
||||
vec = &info.base_margin;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown float field name " << field;
|
||||
}
|
||||
*out_len = static_cast<bst_ulong>(vec->size());
|
||||
*out_dptr = dmlc::BeginPtr(*vec);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
||||
const char *field,
|
||||
bst_ulong *out_len,
|
||||
const unsigned **out_dptr) {
|
||||
API_BEGIN();
|
||||
const MetaInfo& info = static_cast<const DMatrix*>(handle)->info();
|
||||
const std::vector<unsigned>* vec = nullptr;
|
||||
if (!std::strcmp(field, "root_index")) {
|
||||
vec = &info.root_index;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown uint field name " << field;
|
||||
}
|
||||
*out_len = static_cast<bst_ulong>(vec->size());
|
||||
*out_dptr = dmlc::BeginPtr(*vec);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixNumRow(const DMatrixHandle handle,
|
||||
bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
*out = static_cast<bst_ulong>(static_cast<const DMatrix*>(handle)->info().num_row);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
*out = static_cast<size_t>(static_cast<const DMatrix*>(handle)->info().num_col);
|
||||
API_END();
|
||||
}
|
||||
|
||||
// xgboost implementation
|
||||
int XGBoosterCreate(DMatrixHandle dmats[],
|
||||
bst_ulong len,
|
||||
BoosterHandle *out) {
|
||||
API_BEGIN();
|
||||
std::vector<DMatrix*> mats;
|
||||
for (bst_ulong i = 0; i < len; ++i) {
|
||||
mats.push_back(static_cast<DMatrix*>(dmats[i]));
|
||||
}
|
||||
*out = new Booster(mats);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterFree(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
delete static_cast<Booster*>(handle);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterSetParam(BoosterHandle handle,
|
||||
const char *name,
|
||||
const char *value) {
|
||||
API_BEGIN();
|
||||
static_cast<Booster*>(handle)->SetParam(name, value);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterUpdateOneIter(BoosterHandle handle,
|
||||
int iter,
|
||||
DMatrixHandle dtrain) {
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
DMatrix *dtr = static_cast<DMatrix*>(dtrain);
|
||||
|
||||
bst->LazyInit();
|
||||
bst->learner()->UpdateOneIter(iter, dtr);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterBoostOneIter(BoosterHandle handle,
|
||||
DMatrixHandle dtrain,
|
||||
float *grad,
|
||||
float *hess,
|
||||
bst_ulong len) {
|
||||
std::vector<bst_gpair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
DMatrix* dtr = static_cast<DMatrix*>(dtrain);
|
||||
tmp_gpair.resize(len);
|
||||
for (bst_ulong i = 0; i < len; ++i) {
|
||||
tmp_gpair[i] = bst_gpair(grad[i], hess[i]);
|
||||
}
|
||||
|
||||
bst->LazyInit();
|
||||
bst->learner()->BoostOneIter(0, dtr, &tmp_gpair);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterEvalOneIter(BoosterHandle handle,
|
||||
int iter,
|
||||
DMatrixHandle dmats[],
|
||||
const char* evnames[],
|
||||
bst_ulong len,
|
||||
const char** out_str) {
|
||||
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
API_BEGIN();
|
||||
Booster* bst = static_cast<Booster*>(handle);
|
||||
std::vector<DMatrix*> data_sets;
|
||||
std::vector<std::string> data_names;
|
||||
|
||||
for (bst_ulong i = 0; i < len; ++i) {
|
||||
data_sets.push_back(static_cast<DMatrix*>(dmats[i]));
|
||||
data_names.push_back(std::string(evnames[i]));
|
||||
}
|
||||
|
||||
bst->LazyInit();
|
||||
eval_str = bst->learner()->EvalOneIter(iter, data_sets, data_names);
|
||||
*out_str = eval_str.c_str();
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterPredict(BoosterHandle handle,
|
||||
DMatrixHandle dmat,
|
||||
int option_mask,
|
||||
unsigned ntree_limit,
|
||||
bst_ulong *len,
|
||||
const float **out_result) {
|
||||
std::vector<float>& preds = XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
API_BEGIN();
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
bst->learner()->Predict(
|
||||
static_cast<DMatrix*>(dmat),
|
||||
(option_mask & 1) != 0,
|
||||
&preds, ntree_limit,
|
||||
(option_mask & 2) != 0);
|
||||
*out_result = dmlc::BeginPtr(preds);
|
||||
*len = static_cast<bst_ulong>(preds.size());
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
API_BEGIN();
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
|
||||
static_cast<Booster*>(handle)->LoadModel(fi.get());
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterSaveModel(BoosterHandle handle, const char* fname) {
|
||||
API_BEGIN();
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
bst->learner()->Save(fo.get());
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
|
||||
const void* buf,
|
||||
bst_ulong len) {
|
||||
API_BEGIN();
|
||||
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
||||
static_cast<Booster*>(handle)->LoadModel(&fs);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterGetModelRaw(BoosterHandle handle,
|
||||
bst_ulong* out_len,
|
||||
const char** out_dptr) {
|
||||
std::string& raw_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
raw_str.resize(0);
|
||||
|
||||
API_BEGIN();
|
||||
common::MemoryBufferStream fo(&raw_str);
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
bst->learner()->Save(&fo);
|
||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||
*out_len = static_cast<bst_ulong>(raw_str.length());
|
||||
API_END();
|
||||
}
|
||||
|
||||
inline void XGBoostDumpModelImpl(
|
||||
BoosterHandle handle,
|
||||
const FeatureMap& fmap,
|
||||
int with_stats,
|
||||
bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
std::vector<std::string>& str_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_str;
|
||||
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
|
||||
Booster *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
str_vecs = bst->learner()->Dump2Text(fmap, with_stats != 0);
|
||||
charp_vecs.resize(str_vecs.size());
|
||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||
charp_vecs[i] = str_vecs[i].c_str();
|
||||
}
|
||||
*out_models = dmlc::BeginPtr(charp_vecs);
|
||||
*len = static_cast<bst_ulong>(charp_vecs.size());
|
||||
}
|
||||
int XGBoosterDumpModel(BoosterHandle handle,
|
||||
const char* fmap,
|
||||
int with_stats,
|
||||
bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
API_BEGIN();
|
||||
FeatureMap featmap;
|
||||
if (strlen(fmap) != 0) {
|
||||
std::unique_ptr<dmlc::Stream> fs(
|
||||
dmlc::Stream::Create(fmap, "r"));
|
||||
dmlc::istream is(fs.get());
|
||||
featmap.LoadText(is);
|
||||
}
|
||||
XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models);
|
||||
API_END();
|
||||
}
|
||||
|
||||
int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
||||
int fnum,
|
||||
const char** fname,
|
||||
const char** ftype,
|
||||
int with_stats,
|
||||
bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
API_BEGIN();
|
||||
FeatureMap featmap;
|
||||
for (int i = 0; i < fnum; ++i) {
|
||||
featmap.PushBack(i, fname[i], ftype[i]);
|
||||
}
|
||||
XGBoostDumpModelImpl(handle, featmap, with_stats, len, out_models);
|
||||
API_END();
|
||||
}
|
||||
21
src/c_api/c_api_error.cc
Normal file
21
src/c_api/c_api_error.cc
Normal file
@@ -0,0 +1,21 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
* \file c_api_error.cc
|
||||
* \brief C error handling
|
||||
*/
|
||||
#include "./c_api_error.h"
|
||||
#include "../common/thread_local.h"
|
||||
|
||||
struct XGBAPIErrorEntry {
|
||||
std::string last_error;
|
||||
};
|
||||
|
||||
typedef xgboost::common::ThreadLocalStore<XGBAPIErrorEntry> XGBAPIErrorStore;
|
||||
|
||||
const char *XGBGetLastError() {
|
||||
return XGBAPIErrorStore::Get()->last_error.c_str();
|
||||
}
|
||||
|
||||
void XGBAPISetLastError(const char* msg) {
|
||||
XGBAPIErrorStore::Get()->last_error = msg;
|
||||
}
|
||||
39
src/c_api/c_api_error.h
Normal file
39
src/c_api/c_api_error.h
Normal file
@@ -0,0 +1,39 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
* \file c_api_error.h
|
||||
* \brief Error handling for C API.
|
||||
*/
|
||||
#ifndef XGBOOST_C_API_C_API_ERROR_H_
|
||||
#define XGBOOST_C_API_C_API_ERROR_H_
|
||||
|
||||
#include <dmlc/base.h>
|
||||
#include <dmlc/logging.h>
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
/*! \brief macro to guard beginning and end section of all functions */
|
||||
#define API_BEGIN() try {
|
||||
/*! \brief every function starts with API_BEGIN();
|
||||
and finishes with API_END() or API_END_HANDLE_ERROR */
|
||||
#define API_END() } catch(dmlc::Error &_except_) { return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
|
||||
/*!
|
||||
* \brief every function starts with API_BEGIN();
|
||||
* and finishes with API_END() or API_END_HANDLE_ERROR
|
||||
* The finally clause contains procedure to cleanup states when an error happens.
|
||||
*/
|
||||
#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &_except_) { Finalize; return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
|
||||
|
||||
/*!
|
||||
* \brief Set the last error message needed by C API
|
||||
* \param msg The error message to set.
|
||||
*/
|
||||
void XGBAPISetLastError(const char* msg);
|
||||
/*!
|
||||
* \brief handle exception throwed out
|
||||
* \param e the exception
|
||||
* \return the return value of API after exception is handled
|
||||
*/
|
||||
inline int XGBAPIHandleException(const dmlc::Error &e) {
|
||||
XGBAPISetLastError(e.what());
|
||||
return -1;
|
||||
}
|
||||
#endif // XGBOOST_C_API_C_API_ERROR_H_
|
||||
@@ -11,8 +11,9 @@
|
||||
|
||||
#include <xgboost/learner.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <dmlc/logging.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <dmlc/timer.h>
|
||||
#include <iomanip>
|
||||
#include <ctime>
|
||||
#include <string>
|
||||
#include <cstdio>
|
||||
@@ -107,6 +108,8 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
|
||||
.describe("Data split mode.");
|
||||
DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
|
||||
.describe("Number of trees used for prediction, 0 means use all trees.");
|
||||
DMLC_DECLARE_FIELD(pred_margin).set_default(false)
|
||||
.describe("Whether to predict margin value instead of probability.");
|
||||
DMLC_DECLARE_FIELD(dump_stats).set_default(false)
|
||||
.describe("Whether dump the model statistics.");
|
||||
DMLC_DECLARE_FIELD(name_fmap).set_default("NULL")
|
||||
@@ -115,7 +118,8 @@ struct CLIParam : public dmlc::Parameter<CLIParam> {
|
||||
.describe("Name of the output dump text file.");
|
||||
// alias
|
||||
DMLC_DECLARE_ALIAS(train_path, data);
|
||||
DMLC_DECLARE_ALIAS(test_path, "test:data");
|
||||
DMLC_DECLARE_ALIAS(test_path, test:data);
|
||||
DMLC_DECLARE_ALIAS(name_fmap, fmap);
|
||||
}
|
||||
// customized configure function of CLIParam
|
||||
inline void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) {
|
||||
@@ -149,7 +153,7 @@ DMLC_REGISTER_PARAMETER(CLIParam);
|
||||
void CLITrain(const CLIParam& param) {
|
||||
if (rabit::IsDistributed()) {
|
||||
std::string pname = rabit::GetProcessorName();
|
||||
LOG(INFO) << "start " << pname << ":" << rabit::GetRank();
|
||||
LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
|
||||
}
|
||||
// load in data.
|
||||
std::unique_ptr<DMatrix> dtrain(
|
||||
@@ -178,6 +182,8 @@ void CLITrain(const CLIParam& param) {
|
||||
std::unique_ptr<dmlc::Stream> fi(
|
||||
dmlc::Stream::Create(param.model_in.c_str(), "r"));
|
||||
learner->Load(fi.get());
|
||||
} else {
|
||||
learner->InitModel();
|
||||
}
|
||||
}
|
||||
// start training.
|
||||
@@ -186,7 +192,7 @@ void CLITrain(const CLIParam& param) {
|
||||
double elapsed = dmlc::GetTime() - start;
|
||||
if (version % 2 == 0) {
|
||||
if (param.silent == 0) {
|
||||
LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
|
||||
LOG(CONSOLE) << "boosting round " << i << ", " << elapsed << " sec elapsed";
|
||||
}
|
||||
learner->UpdateOneIter(i, dtrain.get());
|
||||
if (learner->AllowLazyCheckPoint()) {
|
||||
@@ -200,16 +206,18 @@ void CLITrain(const CLIParam& param) {
|
||||
std::string res = learner->EvalOneIter(i, eval_datasets, eval_data_names);
|
||||
if (rabit::IsDistributed()) {
|
||||
if (rabit::GetRank() == 0) {
|
||||
rabit::TrackerPrint(res + "\n");
|
||||
LOG(TRACKER) << res;
|
||||
}
|
||||
} else {
|
||||
if (param.silent < 2) {
|
||||
LOG(INFO) << res;
|
||||
LOG(CONSOLE) << res;
|
||||
}
|
||||
}
|
||||
if (param.save_period != 0 && (i + 1) % param.save_period == 0) {
|
||||
std::ostringstream os;
|
||||
os << param.model_dir << '/' << i + 1 << ".model";
|
||||
os << param.model_dir << '/'
|
||||
<< std::setfill('0') << std::setw(4)
|
||||
<< i + 1 << ".model";
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(os.str().c_str(), "w"));
|
||||
learner->Save(fo.get());
|
||||
@@ -228,7 +236,9 @@ void CLITrain(const CLIParam& param) {
|
||||
param.model_out != "NONE") {
|
||||
std::ostringstream os;
|
||||
if (param.model_out == "NULL") {
|
||||
os << param.model_dir << '/' << param.num_round << ".model";
|
||||
os << param.model_dir << '/'
|
||||
<< std::setfill('0') << std::setw(4)
|
||||
<< param.num_round << ".model";
|
||||
} else {
|
||||
os << param.model_out;
|
||||
}
|
||||
@@ -239,7 +249,7 @@ void CLITrain(const CLIParam& param) {
|
||||
|
||||
if (param.silent == 0) {
|
||||
double elapsed = dmlc::GetTime() - start;
|
||||
LOG(INFO) << "update end, " << elapsed << " sec in all";
|
||||
LOG(CONSOLE) << "update end, " << elapsed << " sec in all";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,6 +282,8 @@ void CLIDump2Text(const CLIParam& param) {
|
||||
}
|
||||
|
||||
void CLIPredict(const CLIParam& param) {
|
||||
CHECK_NE(param.test_path, "NULL")
|
||||
<< "Test dataset parameter test:data must be specified.";
|
||||
// load data
|
||||
std::unique_ptr<DMatrix> dtest(
|
||||
DMatrix::Load(param.test_path, param.silent != 0, param.dsplit == 2));
|
||||
@@ -284,12 +296,12 @@ void CLIPredict(const CLIParam& param) {
|
||||
learner->Load(fi.get());
|
||||
|
||||
if (param.silent == 0) {
|
||||
LOG(INFO) << "start prediction...";
|
||||
LOG(CONSOLE) << "start prediction...";
|
||||
}
|
||||
std::vector<float> preds;
|
||||
learner->Predict(dtest.get(), param.pred_margin, &preds, param.ntree_limit);
|
||||
if (param.silent == 0) {
|
||||
LOG(INFO) << "writing prediction to " << param.name_pred;
|
||||
LOG(CONSOLE) << "writing prediction to " << param.name_pred;
|
||||
}
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(param.name_pred.c_str(), "w"));
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#ifndef XGBOOST_COMMON_BASE64_H_
|
||||
#define XGBOOST_COMMON_BASE64_H_
|
||||
|
||||
#include <dmlc/logging.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <cctype>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
|
||||
15
src/common/common.cc
Normal file
15
src/common/common.cc
Normal file
@@ -0,0 +1,15 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file common.cc
|
||||
* \brief Enable all kinds of global variables in common.
|
||||
*/
|
||||
#include "./random.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
RandomEngine& GlobalRandom() {
|
||||
static RandomEngine inst;
|
||||
return inst;
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
@@ -8,7 +8,7 @@
|
||||
#define XGBOOST_COMMON_QUANTILE_H_
|
||||
|
||||
#include <dmlc/base.h>
|
||||
#include <dmlc/logging.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
|
||||
77
src/common/thread_local.h
Normal file
77
src/common/thread_local.h
Normal file
@@ -0,0 +1,77 @@
|
||||
/*!
|
||||
* Copyright (c) 2015 by Contributors
|
||||
* \file thread_local.h
|
||||
* \brief Common utility for thread local storage.
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_THREAD_LOCAL_H_
|
||||
#define XGBOOST_COMMON_THREAD_LOCAL_H_
|
||||
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
// macro hanlding for threadlocal variables
|
||||
#ifdef __GNUC__
|
||||
#define MX_TREAD_LOCAL __thread
|
||||
#elif __STDC_VERSION__ >= 201112L
|
||||
#define MX_TREAD_LOCAL _Thread_local
|
||||
#elif defined(_MSC_VER)
|
||||
#define MX_TREAD_LOCAL __declspec(thread)
|
||||
#endif
|
||||
|
||||
#ifndef MX_TREAD_LOCAL
|
||||
#message("Warning: Threadlocal is not enabled");
|
||||
#endif
|
||||
|
||||
/*!
|
||||
* \brief A threadlocal store to store threadlocal variables.
|
||||
* Will return a thread local singleton of type T
|
||||
* \tparam T the type we like to store
|
||||
*/
|
||||
template<typename T>
|
||||
class ThreadLocalStore {
|
||||
public:
|
||||
/*! \return get a thread local singleton */
|
||||
static T* Get() {
|
||||
static MX_TREAD_LOCAL T* ptr = nullptr;
|
||||
if (ptr == nullptr) {
|
||||
ptr = new T();
|
||||
Singleton()->RegisterDelete(ptr);
|
||||
}
|
||||
return ptr;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief constructor */
|
||||
ThreadLocalStore() {}
|
||||
/*! \brief destructor */
|
||||
~ThreadLocalStore() {
|
||||
for (size_t i = 0; i < data_.size(); ++i) {
|
||||
delete data_[i];
|
||||
}
|
||||
}
|
||||
/*! \return singleton of the store */
|
||||
static ThreadLocalStore<T> *Singleton() {
|
||||
static ThreadLocalStore<T> inst;
|
||||
return &inst;
|
||||
}
|
||||
/*!
|
||||
* \brief register str for internal deletion
|
||||
* \param str the string pointer
|
||||
*/
|
||||
void RegisterDelete(T *str) {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
data_.push_back(str);
|
||||
lock.unlock();
|
||||
}
|
||||
/*! \brief internal mutex */
|
||||
std::mutex mutex_;
|
||||
/*!\brief internal data */
|
||||
std::vector<T*> data_;
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_THREAD_LOCAL_H_
|
||||
@@ -3,7 +3,12 @@
|
||||
* \file data.cc
|
||||
*/
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <cstring>
|
||||
#include "./sparse_batch_page.h"
|
||||
#include "./simple_dmatrix.h"
|
||||
#include "./simple_csr_source.h"
|
||||
#include "../common/io.h"
|
||||
|
||||
namespace xgboost {
|
||||
// implementation of inline functions
|
||||
@@ -83,4 +88,83 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
DMatrix* DMatrix::Load(const std::string& uri,
|
||||
bool silent,
|
||||
bool load_row_split,
|
||||
const std::string& file_format) {
|
||||
std::string fname, cache_file;
|
||||
size_t dlm_pos = uri.find('#');
|
||||
if (dlm_pos != std::string::npos) {
|
||||
cache_file = uri.substr(dlm_pos + 1, uri.length());
|
||||
fname = uri.substr(0, dlm_pos);
|
||||
CHECK_EQ(cache_file.find('#'), std::string::npos)
|
||||
<< "Only one `#` is allowed in file path for cache file specification.";
|
||||
if (load_row_split) {
|
||||
std::ostringstream os;
|
||||
os << cache_file << ".r" << rabit::GetRank();
|
||||
cache_file = os.str();
|
||||
}
|
||||
} else {
|
||||
fname = uri;
|
||||
}
|
||||
int partid = 0, npart = 1;
|
||||
if (load_row_split) {
|
||||
partid = rabit::GetRank();
|
||||
npart = rabit::GetWorldSize();
|
||||
}
|
||||
|
||||
// legacy handling of binary data loading
|
||||
if (file_format == "auto" && !load_row_split) {
|
||||
int magic;
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
|
||||
common::PeekableInStream is(fi.get());
|
||||
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic) &&
|
||||
magic == data::SimpleCSRSource::kMagic) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
source->LoadBinary(&is);
|
||||
DMatrix* dmat = DMatrix::Create(std::move(source), cache_file);
|
||||
if (!silent) {
|
||||
LOG(CONSOLE) << dmat->info().num_row << 'x' << dmat->info().num_col << " matrix with "
|
||||
<< dmat->info().num_nonzero << " entries loaded from " << uri;
|
||||
}
|
||||
return dmat;
|
||||
}
|
||||
}
|
||||
|
||||
std::string ftype = file_format;
|
||||
if (file_format == "auto") ftype = "libsvm";
|
||||
std::unique_ptr<dmlc::Parser<uint32_t> > parser(
|
||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, ftype.c_str()));
|
||||
DMatrix* dmat = DMatrix::Create(parser.get(), cache_file);
|
||||
if (!silent) {
|
||||
LOG(CONSOLE) << dmat->info().num_row << 'x' << dmat->info().num_col << " matrix with "
|
||||
<< dmat->info().num_nonzero << " entries loaded from " << uri;
|
||||
}
|
||||
return dmat;
|
||||
}
|
||||
|
||||
DMatrix* DMatrix::Create(dmlc::Parser<uint32_t>* parser,
|
||||
const std::string& cache_prefix) {
|
||||
if (cache_prefix.length() == 0) {
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
source->CopyFrom(parser);
|
||||
return DMatrix::Create(std::move(source), cache_prefix);
|
||||
} else {
|
||||
LOG(FATAL) << "external memory not yet implemented";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void DMatrix::SaveToLocalFile(const std::string& fname) {
|
||||
data::SimpleCSRSource source;
|
||||
source.CopyFrom(this);
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||
source.SaveBinary(fo.get());
|
||||
}
|
||||
|
||||
DMatrix* DMatrix::Create(std::unique_ptr<DataSource>&& source,
|
||||
const std::string& cache_prefix) {
|
||||
return new data::SimpleDMatrix(std::move(source));
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
* \file simple_csr_source.cc
|
||||
*/
|
||||
#include <dmlc/base.h>
|
||||
#include <dmlc/logging.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include "./simple_csr_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -80,7 +80,7 @@ void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const {
|
||||
}
|
||||
|
||||
void SimpleCSRSource::BeforeFirst() {
|
||||
at_first_ = false;
|
||||
at_first_ = true;
|
||||
}
|
||||
|
||||
bool SimpleCSRSource::Next() {
|
||||
|
||||
265
src/data/simple_dmatrix.cc
Normal file
265
src/data/simple_dmatrix.cc
Normal file
@@ -0,0 +1,265 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* \file simple_dmatrix.cc
|
||||
* \brief the input data structure for gradient boosting
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/data.h>
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "./simple_dmatrix.h"
|
||||
#include "../common/random.h"
|
||||
#include "../common/group_data.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
bool SimpleDMatrix::ColBatchIter::Next() {
|
||||
if (data_ptr_ >= cpages_.size()) return false;
|
||||
data_ptr_ += 1;
|
||||
SparsePage* pcol = cpages_[data_ptr_ - 1].get();
|
||||
batch_.size = col_index_.size();
|
||||
col_data_.resize(col_index_.size(), SparseBatch::Inst(NULL, 0));
|
||||
for (size_t i = 0; i < col_data_.size(); ++i) {
|
||||
const bst_uint ridx = col_index_[i];
|
||||
col_data_[i] = SparseBatch::Inst
|
||||
(dmlc::BeginPtr(pcol->data) + pcol->offset[ridx],
|
||||
static_cast<bst_uint>(pcol->offset[ridx + 1] - pcol->offset[ridx]));
|
||||
}
|
||||
batch_.col_index = dmlc::BeginPtr(col_index_);
|
||||
batch_.col_data = dmlc::BeginPtr(col_data_);
|
||||
return true;
|
||||
}
|
||||
|
||||
dmlc::DataIter<ColBatch>* SimpleDMatrix::ColIterator() {
|
||||
size_t ncol = this->info().num_col;
|
||||
col_iter_.col_index_.resize(ncol);
|
||||
for (size_t i = 0; i < ncol; ++i) {
|
||||
col_iter_.col_index_[i] = static_cast<bst_uint>(i);
|
||||
}
|
||||
col_iter_.BeforeFirst();
|
||||
return &col_iter_;
|
||||
}
|
||||
|
||||
dmlc::DataIter<ColBatch>* SimpleDMatrix::ColIterator(const std::vector<bst_uint>&fset) {
|
||||
size_t ncol = this->info().num_col;
|
||||
col_iter_.col_index_.resize(0);
|
||||
for (size_t i = 0; i < fset.size(); ++i) {
|
||||
if (fset[i] < ncol) col_iter_.col_index_.push_back(fset[i]);
|
||||
}
|
||||
col_iter_.BeforeFirst();
|
||||
return &col_iter_;
|
||||
}
|
||||
|
||||
void SimpleDMatrix::InitColAccess(const std::vector<bool> &enabled,
|
||||
float pkeep,
|
||||
size_t max_row_perbatch) {
|
||||
if (this->HaveColAccess()) return;
|
||||
|
||||
col_iter_.cpages_.clear();
|
||||
if (info().num_row < max_row_perbatch) {
|
||||
std::unique_ptr<SparsePage> page(new SparsePage());
|
||||
this->MakeOneBatch(enabled, pkeep, page.get());
|
||||
col_iter_.cpages_.push_back(std::move(page));
|
||||
} else {
|
||||
this->MakeManyBatch(enabled, pkeep, max_row_perbatch);
|
||||
}
|
||||
// setup col-size
|
||||
col_size_.resize(info().num_col);
|
||||
std::fill(col_size_.begin(), col_size_.end(), 0);
|
||||
for (size_t i = 0; i < col_iter_.cpages_.size(); ++i) {
|
||||
SparsePage *pcol = col_iter_.cpages_[i].get();
|
||||
for (size_t j = 0; j < pcol->Size(); ++j) {
|
||||
col_size_[j] += pcol->offset[j + 1] - pcol->offset[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// internal function to make one batch from row iter.
|
||||
void SimpleDMatrix::MakeOneBatch(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
SparsePage *pcol) {
|
||||
// clear rowset
|
||||
buffered_rowset_.clear();
|
||||
// bit map
|
||||
int nthread;
|
||||
std::vector<bool> bmap;
|
||||
#pragma omp parallel
|
||||
{
|
||||
nthread = omp_get_num_threads();
|
||||
}
|
||||
|
||||
pcol->Clear();
|
||||
common::ParallelGroupBuilder<SparseBatch::Entry>
|
||||
builder(&pcol->offset, &pcol->data);
|
||||
builder.InitBudget(info().num_col, nthread);
|
||||
// start working
|
||||
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch& batch = iter->Value();
|
||||
bmap.resize(bmap.size() + batch.size, true);
|
||||
std::bernoulli_distribution coin_flip(pkeep);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
|
||||
long batch_size = static_cast<long>(batch.size); // NOLINT(*)
|
||||
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
|
||||
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
if (pkeep == 1.0f || coin_flip(rnd)) {
|
||||
buffered_rowset_.push_back(ridx);
|
||||
} else {
|
||||
bmap[i] = false;
|
||||
}
|
||||
}
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
if (bmap[ridx]) {
|
||||
RowBatch::Inst inst = batch[i];
|
||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||
if (enabled[inst[j].index]) {
|
||||
builder.AddBudget(inst[j].index, tid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
builder.InitStorage();
|
||||
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch& batch = iter->Value();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long i = 0; i < static_cast<long>(batch.size); ++i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
if (bmap[ridx]) {
|
||||
RowBatch::Inst inst = batch[i];
|
||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||
if (enabled[inst[j].index]) {
|
||||
builder.Push(inst[j].index,
|
||||
SparseBatch::Entry((bst_uint)(batch.base_rowid+i),
|
||||
inst[j].fvalue), tid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CHECK_EQ(pcol->Size(), info().num_col);
|
||||
// sort columns
|
||||
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (pcol->offset[i] < pcol->offset[i + 1]) {
|
||||
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
|
||||
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
|
||||
SparseBatch::Entry::CmpValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SimpleDMatrix::MakeManyBatch(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
size_t max_row_perbatch) {
|
||||
size_t btop = 0;
|
||||
std::bernoulli_distribution coin_flip(pkeep);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
buffered_rowset_.clear();
|
||||
// internal temp cache
|
||||
SparsePage tmp; tmp.Clear();
|
||||
// start working
|
||||
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
if (pkeep == 1.0f || coin_flip(rnd)) {
|
||||
buffered_rowset_.push_back(ridx);
|
||||
tmp.Push(batch[i]);
|
||||
}
|
||||
if (tmp.Size() >= max_row_perbatch) {
|
||||
std::unique_ptr<SparsePage> page(new SparsePage());
|
||||
this->MakeColPage(tmp.GetRowBatch(0),
|
||||
dmlc::BeginPtr(buffered_rowset_) + btop,
|
||||
enabled, page.get());
|
||||
col_iter_.cpages_.push_back(std::move(page));
|
||||
btop = buffered_rowset_.size();
|
||||
tmp.Clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tmp.Size() != 0) {
|
||||
std::unique_ptr<SparsePage> page(new SparsePage());
|
||||
this->MakeColPage(tmp.GetRowBatch(0),
|
||||
dmlc::BeginPtr(buffered_rowset_) + btop,
|
||||
enabled, page.get());
|
||||
col_iter_.cpages_.push_back(std::move(page));
|
||||
}
|
||||
}
|
||||
|
||||
// make column page from subset of rowbatchs
|
||||
void SimpleDMatrix::MakeColPage(const RowBatch& batch,
|
||||
const bst_uint* ridx,
|
||||
const std::vector<bool>& enabled,
|
||||
SparsePage* pcol) {
|
||||
int nthread;
|
||||
#pragma omp parallel
|
||||
{
|
||||
nthread = omp_get_num_threads();
|
||||
int max_nthread = std::max(omp_get_num_procs() / 2 - 2, 1);
|
||||
if (nthread > max_nthread) {
|
||||
nthread = max_nthread;
|
||||
}
|
||||
}
|
||||
pcol->Clear();
|
||||
common::ParallelGroupBuilder<SparseBatch::Entry>
|
||||
builder(&pcol->offset, &pcol->data);
|
||||
builder.InitBudget(info().num_col, nthread);
|
||||
bst_omp_uint ndata = static_cast<bst_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
int tid = omp_get_thread_num();
|
||||
RowBatch::Inst inst = batch[i];
|
||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||
const SparseBatch::Entry &e = inst[j];
|
||||
if (enabled[e.index]) {
|
||||
builder.AddBudget(e.index, tid);
|
||||
}
|
||||
}
|
||||
}
|
||||
builder.InitStorage();
|
||||
#pragma omp parallel for schedule(static) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
int tid = omp_get_thread_num();
|
||||
RowBatch::Inst inst = batch[i];
|
||||
for (bst_uint j = 0; j < inst.length; ++j) {
|
||||
const SparseBatch::Entry &e = inst[j];
|
||||
builder.Push(e.index,
|
||||
SparseBatch::Entry(ridx[i], e.fvalue),
|
||||
tid);
|
||||
}
|
||||
}
|
||||
CHECK_EQ(pcol->Size(), info().num_col);
|
||||
// sort columns
|
||||
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (pcol->offset[i] < pcol->offset[i + 1]) {
|
||||
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
|
||||
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
|
||||
SparseBatch::Entry::CmpValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool SimpleDMatrix::SingleColBlock() const {
|
||||
return col_iter_.cpages_.size() <= 1;
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
119
src/data/simple_dmatrix.h
Normal file
119
src/data/simple_dmatrix.h
Normal file
@@ -0,0 +1,119 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file simple_dmatrix.h
|
||||
* \brief In-memory version of DMatrix.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_SIMPLE_DMATRIX_H_
|
||||
#define XGBOOST_DATA_SIMPLE_DMATRIX_H_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include "./sparse_batch_page.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
class SimpleDMatrix : public DMatrix {
|
||||
public:
|
||||
explicit SimpleDMatrix(std::unique_ptr<DataSource>&& source)
|
||||
: source_(std::move(source)) {}
|
||||
|
||||
MetaInfo& info() override {
|
||||
return source_->info;
|
||||
}
|
||||
|
||||
const MetaInfo& info() const override {
|
||||
return source_->info;
|
||||
}
|
||||
|
||||
dmlc::DataIter<RowBatch>* RowIterator() override {
|
||||
dmlc::DataIter<RowBatch>* iter = source_.get();
|
||||
iter->BeforeFirst();
|
||||
return iter;
|
||||
}
|
||||
|
||||
bool HaveColAccess() const override {
|
||||
return col_size_.size() != 0;
|
||||
}
|
||||
|
||||
const std::vector<bst_uint>& buffered_rowset() const override {
|
||||
return buffered_rowset_;
|
||||
}
|
||||
|
||||
size_t GetColSize(size_t cidx) const {
|
||||
return col_size_[cidx];
|
||||
}
|
||||
|
||||
float GetColDensity(size_t cidx) const override {
|
||||
size_t nmiss = buffered_rowset_.size() - col_size_[cidx];
|
||||
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
|
||||
}
|
||||
|
||||
dmlc::DataIter<ColBatch>* ColIterator() override;
|
||||
|
||||
dmlc::DataIter<ColBatch>* ColIterator(const std::vector<bst_uint>& fset) override;
|
||||
|
||||
void InitColAccess(const std::vector<bool>& enabled,
|
||||
float subsample,
|
||||
size_t max_row_perbatch) override;
|
||||
|
||||
bool SingleColBlock() const override;
|
||||
|
||||
private:
|
||||
// in-memory column batch iterator.
|
||||
struct ColBatchIter: dmlc::DataIter<ColBatch> {
|
||||
public:
|
||||
ColBatchIter() : data_ptr_(0) {}
|
||||
void BeforeFirst() override {
|
||||
data_ptr_ = 0;
|
||||
}
|
||||
const ColBatch &Value() const override {
|
||||
return batch_;
|
||||
}
|
||||
bool Next() override;
|
||||
|
||||
private:
|
||||
// allow SimpleDMatrix to access it.
|
||||
friend class SimpleDMatrix;
|
||||
// data content
|
||||
std::vector<bst_uint> col_index_;
|
||||
// column content
|
||||
std::vector<ColBatch::Inst> col_data_;
|
||||
// column sparse pages
|
||||
std::vector<std::unique_ptr<SparsePage> > cpages_;
|
||||
// data pointer
|
||||
size_t data_ptr_;
|
||||
// temporal space for batch
|
||||
ColBatch batch_;
|
||||
};
|
||||
|
||||
// source data pointer.
|
||||
std::unique_ptr<DataSource> source_;
|
||||
// column iterator
|
||||
ColBatchIter col_iter_;
|
||||
// list of row index that are buffered.
|
||||
std::vector<bst_uint> buffered_rowset_;
|
||||
/*! \brief sizeof column data */
|
||||
std::vector<size_t> col_size_;
|
||||
|
||||
// internal function to make one batch from row iter.
|
||||
void MakeOneBatch(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
SparsePage *pcol);
|
||||
|
||||
void MakeManyBatch(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
size_t max_row_perbatch);
|
||||
|
||||
void MakeColPage(const RowBatch& batch,
|
||||
const bst_uint* ridx,
|
||||
const std::vector<bool>& enabled,
|
||||
SparsePage* pcol);
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_SIMPLE_DMATRIX_H_
|
||||
214
src/data/sparse_batch_page.h
Normal file
214
src/data/sparse_batch_page.h
Normal file
@@ -0,0 +1,214 @@
|
||||
/*!
|
||||
* Copyright (c) 2014 by Contributors
|
||||
* \file sparse_batch_page.h
|
||||
* content holder of sparse batch that can be saved to disk
|
||||
* the representation can be effectively
|
||||
* use in external memory computation
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_SPARSE_BATCH_PAGE_H_
|
||||
#define XGBOOST_DATA_SPARSE_BATCH_PAGE_H_
|
||||
|
||||
#include <xgboost/data.h>
|
||||
#include <dmlc/io.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
/*!
|
||||
* \brief in-memory storage unit of sparse batch
|
||||
*/
|
||||
class SparsePage {
|
||||
public:
|
||||
/*! \brief offset of the segments */
|
||||
std::vector<size_t> offset;
|
||||
/*! \brief the data of the segments */
|
||||
std::vector<SparseBatch::Entry> data;
|
||||
|
||||
/*! \brief constructor */
|
||||
SparsePage() {
|
||||
this->Clear();
|
||||
}
|
||||
/*! \return number of instance in the page */
|
||||
inline size_t Size() const {
|
||||
return offset.size() - 1;
|
||||
}
|
||||
/*!
|
||||
* \brief load only the segments we are interested in
|
||||
* \param fi the input stream of the file
|
||||
* \param sorted_index_set sorted index of segments we are interested in
|
||||
* \return true of the loading as successful, false if end of file was reached
|
||||
*/
|
||||
inline bool Load(dmlc::SeekStream *fi,
|
||||
const std::vector<bst_uint> &sorted_index_set) {
|
||||
if (!fi->Read(&disk_offset_)) return false;
|
||||
// setup the offset
|
||||
offset.clear(); offset.push_back(0);
|
||||
for (size_t i = 0; i < sorted_index_set.size(); ++i) {
|
||||
bst_uint fid = sorted_index_set[i];
|
||||
CHECK_LT(fid + 1, disk_offset_.size());
|
||||
size_t size = disk_offset_[fid + 1] - disk_offset_[fid];
|
||||
offset.push_back(offset.back() + size);
|
||||
}
|
||||
data.resize(offset.back());
|
||||
// read in the data
|
||||
size_t begin = fi->Tell();
|
||||
size_t curr_offset = 0;
|
||||
for (size_t i = 0; i < sorted_index_set.size();) {
|
||||
bst_uint fid = sorted_index_set[i];
|
||||
if (disk_offset_[fid] != curr_offset) {
|
||||
CHECK_GT(disk_offset_[fid], curr_offset);
|
||||
fi->Seek(begin + disk_offset_[fid] * sizeof(SparseBatch::Entry));
|
||||
curr_offset = disk_offset_[fid];
|
||||
}
|
||||
size_t j, size_to_read = 0;
|
||||
for (j = i; j < sorted_index_set.size(); ++j) {
|
||||
if (disk_offset_[sorted_index_set[j]] == disk_offset_[fid] + size_to_read) {
|
||||
size_to_read += offset[j + 1] - offset[j];
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (size_to_read != 0) {
|
||||
CHECK_EQ(fi->Read(dmlc::BeginPtr(data) + offset[i],
|
||||
size_to_read * sizeof(SparseBatch::Entry)),
|
||||
size_to_read * sizeof(SparseBatch::Entry))
|
||||
<< "Invalid SparsePage file";
|
||||
curr_offset += size_to_read;
|
||||
}
|
||||
i = j;
|
||||
}
|
||||
// seek to end of record
|
||||
if (curr_offset != disk_offset_.back()) {
|
||||
fi->Seek(begin + disk_offset_.back() * sizeof(SparseBatch::Entry));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
/*!
|
||||
* \brief load all the segments
|
||||
* \param fi the input stream of the file
|
||||
* \return true of the loading as successful, false if end of file was reached
|
||||
*/
|
||||
inline bool Load(dmlc::Stream *fi) {
|
||||
if (!fi->Read(&offset)) return false;
|
||||
CHECK_NE(offset.size(), 0) << "Invalid SparsePage file";
|
||||
data.resize(offset.back());
|
||||
if (data.size() != 0) {
|
||||
CHECK_EQ(fi->Read(dmlc::BeginPtr(data), data.size() * sizeof(SparseBatch::Entry)),
|
||||
data.size() * sizeof(SparseBatch::Entry))
|
||||
<< "Invalid SparsePage file";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
/*!
|
||||
* \brief save the data to fo, when a page was written
|
||||
* to disk it must contain all the elements in the
|
||||
* \param fo output stream
|
||||
*/
|
||||
inline void Save(dmlc::Stream *fo) const {
|
||||
CHECK(offset.size() != 0 && offset[0] == 0);
|
||||
CHECK_EQ(offset.back(), data.size());
|
||||
fo->Write(offset);
|
||||
if (data.size() != 0) {
|
||||
fo->Write(dmlc::BeginPtr(data), data.size() * sizeof(SparseBatch::Entry));
|
||||
}
|
||||
}
|
||||
/*! \return estimation of memory cost of this page */
|
||||
inline size_t MemCostBytes(void) const {
|
||||
return offset.size() * sizeof(size_t) + data.size() * sizeof(SparseBatch::Entry);
|
||||
}
|
||||
/*! \brief clear the page */
|
||||
inline void Clear(void) {
|
||||
offset.clear();
|
||||
offset.push_back(0);
|
||||
data.clear();
|
||||
}
|
||||
/*!
|
||||
* \brief load all the segments and add it to existing batch
|
||||
* \param fi the input stream of the file
|
||||
* \return true of the loading as successful, false if end of file was reached
|
||||
*/
|
||||
inline bool PushLoad(dmlc::Stream *fi) {
|
||||
if (!fi->Read(&disk_offset_)) return false;
|
||||
data.resize(offset.back() + disk_offset_.back());
|
||||
if (disk_offset_.back() != 0) {
|
||||
CHECK_EQ(fi->Read(dmlc::BeginPtr(data) + offset.back(),
|
||||
disk_offset_.back() * sizeof(SparseBatch::Entry)),
|
||||
disk_offset_.back() * sizeof(SparseBatch::Entry))
|
||||
<< "Invalid SparsePage file";
|
||||
}
|
||||
size_t top = offset.back();
|
||||
size_t begin = offset.size();
|
||||
offset.resize(offset.size() + disk_offset_.size());
|
||||
for (size_t i = 0; i < disk_offset_.size(); ++i) {
|
||||
offset[i + begin] = top + disk_offset_[i];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
/*!
|
||||
* \brief Push row batch into the page
|
||||
* \param batch the row batch
|
||||
*/
|
||||
inline void Push(const RowBatch &batch) {
|
||||
data.resize(offset.back() + batch.ind_ptr[batch.size]);
|
||||
std::memcpy(dmlc::BeginPtr(data) + offset.back(),
|
||||
batch.data_ptr + batch.ind_ptr[0],
|
||||
sizeof(SparseBatch::Entry) * batch.ind_ptr[batch.size]);
|
||||
size_t top = offset.back();
|
||||
size_t begin = offset.size();
|
||||
offset.resize(offset.size() + batch.size);
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
offset[i + begin] = top + batch.ind_ptr[i + 1] - batch.ind_ptr[0];
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief Push a sparse page
|
||||
* \param batch the row page
|
||||
*/
|
||||
inline void Push(const SparsePage &batch) {
|
||||
size_t top = offset.back();
|
||||
data.resize(top + batch.data.size());
|
||||
std::memcpy(dmlc::BeginPtr(data) + top,
|
||||
dmlc::BeginPtr(batch.data),
|
||||
sizeof(SparseBatch::Entry) * batch.data.size());
|
||||
size_t begin = offset.size();
|
||||
offset.resize(begin + batch.Size());
|
||||
for (size_t i = 0; i < batch.Size(); ++i) {
|
||||
offset[i + begin] = top + batch.offset[i + 1];
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief Push one instance into page
|
||||
* \param row an instance row
|
||||
*/
|
||||
inline void Push(const SparseBatch::Inst &inst) {
|
||||
offset.push_back(offset.back() + inst.length);
|
||||
size_t begin = data.size();
|
||||
data.resize(begin + inst.length);
|
||||
if (inst.length != 0) {
|
||||
std::memcpy(dmlc::BeginPtr(data) + begin, inst.data,
|
||||
sizeof(SparseBatch::Entry) * inst.length);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \param base_rowid base_rowid of the data
|
||||
* \return row batch representation of the page
|
||||
*/
|
||||
inline RowBatch GetRowBatch(size_t base_rowid) const {
|
||||
RowBatch out;
|
||||
out.base_rowid = base_rowid;
|
||||
out.ind_ptr = dmlc::BeginPtr(offset);
|
||||
out.data_ptr = dmlc::BeginPtr(data);
|
||||
out.size = offset.size() - 1;
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief external memory column offset */
|
||||
std::vector<size_t> disk_offset_;
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_SPARSE_BATCH_PAGE_H_
|
||||
@@ -5,10 +5,10 @@
|
||||
* the update rule is parallel coordinate descent (shotgun)
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/gbm.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
@@ -17,6 +17,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gblinear);
|
||||
|
||||
// model parameter
|
||||
struct GBLinearModelParam :public dmlc::Parameter<GBLinearModelParam> {
|
||||
// number of feature dimension
|
||||
@@ -168,6 +171,9 @@ class GBLinear : public GradientBooster {
|
||||
int64_t buffer_offset,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
if (model.weight.size() == 0) {
|
||||
model.InitModel();
|
||||
}
|
||||
CHECK_EQ(ntree_limit, 0)
|
||||
<< "GBLinear::Predict ntrees is only valid for gbtree predictor";
|
||||
std::vector<float> &preds = *out_preds;
|
||||
@@ -293,4 +299,3 @@ XGBOOST_REGISTER_GBM(GBLinear, "gblinear")
|
||||
});
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
29
src/gbm/gbm.cc
Normal file
29
src/gbm/gbm.cc
Normal file
@@ -0,0 +1,29 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file gbm.cc
|
||||
* \brief Registry of gradient boosters.
|
||||
*/
|
||||
#include <xgboost/gbm.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
GradientBooster* GradientBooster::Create(const std::string& name) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown gbm type " << name;
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(gblinear);
|
||||
DMLC_REGISTRY_LINK_TAG(gbtree);
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
@@ -4,9 +4,9 @@
|
||||
* \brief gradient boosted tree implementation.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/gbm.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gbtree);
|
||||
|
||||
/*! \brief training parameters */
|
||||
struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||
/*! \brief number of threads */
|
||||
@@ -482,4 +484,3 @@ XGBOOST_REGISTER_GBM(GBTree, "gbtree")
|
||||
});
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file global.cc
|
||||
* \brief Enable all kinds of global static registry and variables.
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
#include <xgboost/metric.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <xgboost/gbm.h>
|
||||
#include "./common/random.h"
|
||||
#include "./common/base64.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
// implement factory functions
|
||||
ObjFunction* ObjFunction::Create(const std::string& name) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown objective function " << name;
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
|
||||
Metric* Metric::Create(const std::string& name) {
|
||||
std::string buf = name;
|
||||
std::string prefix = name;
|
||||
auto pos = buf.find('@');
|
||||
if (pos == std::string::npos) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown objective function " << name;
|
||||
}
|
||||
return (e->body)(nullptr);
|
||||
} else {
|
||||
std::string prefix = buf.substr(0, pos);
|
||||
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(prefix.c_str());
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown objective function " << name;
|
||||
}
|
||||
return (e->body)(buf.substr(pos + 1, buf.length()).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
TreeUpdater* TreeUpdater::Create(const std::string& name) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown tree updater " << name;
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
|
||||
GradientBooster* GradientBooster::Create(const std::string& name) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown gbm type " << name;
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
|
||||
namespace common {
|
||||
RandomEngine& GlobalRandom() {
|
||||
static RandomEngine inst;
|
||||
return inst;
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <limits>
|
||||
#include <iomanip>
|
||||
#include "./common/io.h"
|
||||
#include "./common/random.h"
|
||||
|
||||
@@ -94,6 +95,9 @@ struct LearnerTrainParam
|
||||
}
|
||||
};
|
||||
|
||||
DMLC_REGISTER_PARAMETER(LearnerModelParam);
|
||||
DMLC_REGISTER_PARAMETER(LearnerTrainParam);
|
||||
|
||||
/*!
|
||||
* \brief learner that performs gradient boosting for a specific objective function.
|
||||
* It does training and prediction.
|
||||
@@ -144,6 +148,9 @@ class LearnerImpl : public Learner {
|
||||
|
||||
if (cfg_.count("num_class") != 0) {
|
||||
cfg_["num_output_group"] = cfg_["num_class"];
|
||||
if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) {
|
||||
cfg_["objective"] = "multi:softmax";
|
||||
}
|
||||
}
|
||||
|
||||
if (cfg_.count("max_delta_step") == 0 &&
|
||||
@@ -187,6 +194,10 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
}
|
||||
|
||||
void InitModel() override {
|
||||
this->LazyInitModel();
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
// TODO(tqchen) mark deprecation of old format.
|
||||
common::PeekableInStream fp(fi);
|
||||
@@ -202,7 +213,6 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
// use the peekable reader.
|
||||
fi = &fp;
|
||||
std::string name_gbm, name_obj;
|
||||
// read parameter
|
||||
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
|
||||
<< "BoostLearner: wrong model format";
|
||||
@@ -218,7 +228,7 @@ class LearnerImpl : public Learner {
|
||||
len = len >> static_cast<uint64_t>(32UL);
|
||||
}
|
||||
if (len != 0) {
|
||||
name_obj.resize(len);
|
||||
name_obj_.resize(len);
|
||||
CHECK_EQ(fi->Read(&name_obj_[0], len), len)
|
||||
<<"BoostLearner: wrong model format";
|
||||
}
|
||||
@@ -226,8 +236,10 @@ class LearnerImpl : public Learner {
|
||||
CHECK(fi->Read(&name_gbm_))
|
||||
<< "BoostLearner: wrong model format";
|
||||
// duplicated code with LazyInitModel
|
||||
obj_.reset(ObjFunction::Create(cfg_.at(name_obj_)));
|
||||
gbm_.reset(GradientBooster::Create(cfg_.at(name_gbm_)));
|
||||
obj_.reset(ObjFunction::Create(name_obj_));
|
||||
gbm_.reset(GradientBooster::Create(name_gbm_));
|
||||
gbm_->Load(fi);
|
||||
|
||||
if (metrics_.size() == 0) {
|
||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||
}
|
||||
@@ -246,11 +258,12 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||
CHECK(ModelInitialized())
|
||||
<< "Always call InitModel or LoadModel before update";
|
||||
if (tparam.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->LazyInitDMatrix(train);
|
||||
this->LazyInitModel();
|
||||
this->PredictRaw(train, &preds_);
|
||||
obj_->GetGradient(preds_, train->info(), iter, &gpair_);
|
||||
gbm_->DoBoost(train, this->FindBufferOffset(train), &gpair_);
|
||||
@@ -262,6 +275,7 @@ class LearnerImpl : public Learner {
|
||||
if (tparam.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->LazyInitDMatrix(train);
|
||||
gbm_->DoBoost(train, this->FindBufferOffset(train), in_gpair);
|
||||
}
|
||||
|
||||
@@ -269,7 +283,8 @@ class LearnerImpl : public Learner {
|
||||
const std::vector<DMatrix*>& data_sets,
|
||||
const std::vector<std::string>& data_names) override {
|
||||
std::ostringstream os;
|
||||
os << '[' << iter << ']';
|
||||
os << '[' << iter << ']'
|
||||
<< std::setiosflags(std::ios::fixed);
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
this->PredictRaw(data_sets[i], &preds_);
|
||||
obj_->EvalTransform(&preds_);
|
||||
@@ -347,8 +362,6 @@ class LearnerImpl : public Learner {
|
||||
if (num_feature > mparam.num_feature) {
|
||||
mparam.num_feature = num_feature;
|
||||
}
|
||||
// reset the base score
|
||||
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
|
||||
|
||||
// setup
|
||||
cfg_["num_feature"] = ToString(mparam.num_feature);
|
||||
@@ -357,9 +370,13 @@ class LearnerImpl : public Learner {
|
||||
gbm_.reset(GradientBooster::Create(name_gbm_));
|
||||
gbm_->Configure(cfg_.begin(), cfg_.end());
|
||||
obj_->Configure(cfg_.begin(), cfg_.end());
|
||||
|
||||
// reset the base score
|
||||
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
|
||||
if (metrics_.size() == 0) {
|
||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric()));
|
||||
}
|
||||
|
||||
this->base_score_ = mparam.base_score;
|
||||
gbm_->ResetPredBuffer(pred_buffer_size_);
|
||||
}
|
||||
@@ -373,6 +390,8 @@ class LearnerImpl : public Learner {
|
||||
inline void PredictRaw(DMatrix* data,
|
||||
std::vector<float>* out_preds,
|
||||
unsigned ntree_limit = 0) const {
|
||||
CHECK(gbm_.get() != nullptr)
|
||||
<< "Predict must happen after Load or InitModel";
|
||||
gbm_->Predict(data,
|
||||
this->FindBufferOffset(data),
|
||||
out_preds,
|
||||
|
||||
20
src/logging.cc
Normal file
20
src/logging.cc
Normal file
@@ -0,0 +1,20 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file logging.cc
|
||||
* \brief Implementation of loggers.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/logging.h>
|
||||
#include <iostream>
|
||||
#include "./common/sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
ConsoleLogger::~ConsoleLogger() {
|
||||
std::cout << log_stream_.str() << std::endl;
|
||||
}
|
||||
|
||||
TrackerLogger::~TrackerLogger() {
|
||||
log_stream_ << '\n';
|
||||
rabit::TrackerPrint(log_stream_.str());
|
||||
}
|
||||
} // namespace xgboost
|
||||
@@ -5,12 +5,16 @@
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <cmath>
|
||||
#include "../common/math.h"
|
||||
#include "../common/sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
// tag the this file, used by force static link later.
|
||||
DMLC_REGISTRY_FILE_TAG(elementwise_metric);
|
||||
|
||||
/*!
|
||||
* \brief base class of element-wise evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
@@ -124,4 +128,3 @@ XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
|
||||
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
42
src/metric/metric.cc
Normal file
42
src/metric/metric.cc
Normal file
@@ -0,0 +1,42 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file metric_registry.cc
|
||||
* \brief Registry of objective functions.
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
Metric* Metric::Create(const std::string& name) {
|
||||
std::string buf = name;
|
||||
std::string prefix = name;
|
||||
auto pos = buf.find('@');
|
||||
if (pos == std::string::npos) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown metric function " << name;
|
||||
}
|
||||
return (e->body)(nullptr);
|
||||
} else {
|
||||
std::string prefix = buf.substr(0, pos);
|
||||
auto *e = ::dmlc::Registry< ::xgboost::MetricReg>::Get()->Find(prefix.c_str());
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown metric function " << name;
|
||||
}
|
||||
return (e->body)(buf.substr(pos + 1, buf.length()).c_str());
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(elementwise_metric);
|
||||
DMLC_REGISTRY_LINK_TAG(multiclass_metric);
|
||||
DMLC_REGISTRY_LINK_TAG(rank_metric);
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
@@ -11,6 +11,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
// tag the this file, used by force static link later.
|
||||
DMLC_REGISTRY_FILE_TAG(multiclass_metric);
|
||||
|
||||
/*!
|
||||
* \brief base class of multi-class evaluation
|
||||
* \tparam Derived the name of subclass
|
||||
@@ -114,4 +117,3 @@ XGBOOST_REGISTER_METRIC(MultiLogLoss, "mlogloss")
|
||||
.set_body([](const char* param) { return new EvalMultiLogLoss(); });
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@@ -5,12 +5,16 @@
|
||||
* \author Kailong Chen, Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/metric.h>
|
||||
#include <dmlc/registry.h>
|
||||
#include <cmath>
|
||||
#include "../common/sync.h"
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace metric {
|
||||
// tag the this file, used by force static link later.
|
||||
DMLC_REGISTRY_FILE_TAG(rank_metric);
|
||||
|
||||
/*! \brief AMS: also records best threshold */
|
||||
struct EvalAMS : public Metric {
|
||||
public:
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
* \brief Definition of multi-class classification objectives.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
@@ -16,6 +16,8 @@
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(multiclass_obj);
|
||||
|
||||
struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
|
||||
int num_class;
|
||||
// declare parameters
|
||||
|
||||
34
src/objective/objective.cc
Normal file
34
src/objective/objective.cc
Normal file
@@ -0,0 +1,34 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file objective.cc
|
||||
* \brief Registry of all objective functions.
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
// implement factory functions
|
||||
ObjFunction* ObjFunction::Create(const std::string& name) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) {
|
||||
LOG(INFO) << "Objective candidate: " << entry->name;
|
||||
}
|
||||
LOG(FATAL) << "Unknown objective function " << name;
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(regression_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(rank_obj);
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
@@ -4,8 +4,8 @@
|
||||
* \brief Definition of rank loss.
|
||||
* \author Tianqi Chen, Kailong Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
@@ -16,6 +16,8 @@
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(rank_obj);
|
||||
|
||||
struct LambdaRankParam : public dmlc::Parameter<LambdaRankParam> {
|
||||
int num_pairsample;
|
||||
float fix_list_weight;
|
||||
@@ -324,4 +326,3 @@ XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, "rank:map")
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
* \brief Definition of single-value regression and classification objectives.
|
||||
* \author Tianqi Chen, Kailong Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
@@ -14,6 +14,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj);
|
||||
|
||||
// common regressions
|
||||
// linear regression
|
||||
struct LinearSquareLoss {
|
||||
@@ -84,7 +87,9 @@ class RegLossObj : public ObjFunction {
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) override {
|
||||
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size()) << "labels are not correctly provided";
|
||||
CHECK_EQ(preds.size(), info.labels.size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.size() << ", label.size=" << info.labels.size();
|
||||
out_gpair->resize(preds.size());
|
||||
// check if label in range
|
||||
bool label_correct = true;
|
||||
@@ -95,7 +100,7 @@ class RegLossObj : public ObjFunction {
|
||||
float p = Loss::PredTransform(preds[i]);
|
||||
float w = info.GetWeight(i);
|
||||
if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight;
|
||||
if (Loss::CheckLabel(info.labels[i])) label_correct = false;
|
||||
if (!Loss::CheckLabel(info.labels[i])) label_correct = false;
|
||||
out_gpair->at(i) = bst_gpair(Loss::FirstOrderGradient(p, info.labels[i]) * w,
|
||||
Loss::SecondOrderGradient(p, info.labels[i]) * w);
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
|
||||
.describe("L2 regularization on leaf weight");
|
||||
DMLC_DECLARE_FIELD(reg_alpha).set_lower_bound(0.0f).set_default(0.0f)
|
||||
.describe("L1 regularization on leaf weight");
|
||||
DMLC_DECLARE_FIELD(default_direction)
|
||||
DMLC_DECLARE_FIELD(default_direction).set_default(0)
|
||||
.add_enum("learn", 0)
|
||||
.add_enum("left", 1)
|
||||
.add_enum("right", 2)
|
||||
|
||||
35
src/tree/tree_updater.cc
Normal file
35
src/tree/tree_updater.cc
Normal file
@@ -0,0 +1,35 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file tree_updater.cc
|
||||
* \brief Registry of tree updaters.
|
||||
*/
|
||||
#include <xgboost/tree_updater.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::TreeUpdaterReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
TreeUpdater* TreeUpdater::Create(const std::string& name) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown tree updater " << name;
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(updater_colmaker);
|
||||
DMLC_REGISTRY_LINK_TAG(updater_skmaker);
|
||||
DMLC_REGISTRY_LINK_TAG(updater_refresh);
|
||||
DMLC_REGISTRY_LINK_TAG(updater_prune);
|
||||
DMLC_REGISTRY_LINK_TAG(updater_histmaker);
|
||||
DMLC_REGISTRY_LINK_TAG(updater_sync);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
@@ -15,6 +15,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_colmaker);
|
||||
|
||||
/*! \brief column-wise update to construct a tree */
|
||||
template<typename TStats>
|
||||
class ColMaker: public TreeUpdater {
|
||||
@@ -891,4 +894,3 @@ XGBOOST_REGISTER_TREE_UPDATER(DistColMaker, "distcol")
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_histmaker);
|
||||
|
||||
template<typename TStats>
|
||||
class HistMaker: public BaseMaker {
|
||||
public:
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_prune);
|
||||
|
||||
/*! \brief pruner that prunes a tree after growing finishes */
|
||||
class TreePruner: public TreeUpdater {
|
||||
public:
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_refresh);
|
||||
|
||||
/*! \brief pruner that prunes a tree after growing finishs */
|
||||
template<typename TStats>
|
||||
class TreeRefresher: public TreeUpdater {
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_skmaker);
|
||||
|
||||
class SketchMaker: public BaseMaker {
|
||||
public:
|
||||
void Update(const std::vector<bst_gpair> &gpair,
|
||||
@@ -399,4 +401,3 @@ XGBOOST_REGISTER_TREE_UPDATER(SketchMaker, "grow_skmaker")
|
||||
});
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@@ -12,6 +12,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_sync);
|
||||
|
||||
/*!
|
||||
* \brief syncher that synchronize the tree in all distributed nodes
|
||||
* can implement various strategies, so far it is always set to node 0's tree
|
||||
|
||||
Reference in New Issue
Block a user