change omp loop var to bst_omp_uint, add XGB_DLL to wrapper
This commit is contained in:
parent
97467fe807
commit
7739f57c8b
@ -3,10 +3,10 @@
|
||||
#include <utility>
|
||||
#include <cstring>
|
||||
#include "xgboost_R.h"
|
||||
#include "../../wrapper/xgboost_wrapper.h"
|
||||
#include "../../src/utils/utils.h"
|
||||
#include "../../src/utils/omp.h"
|
||||
#include "../../src/utils/matrix_csr.h"
|
||||
#include "xgboost_wrapper.h"
|
||||
#include "../src/utils/utils.h"
|
||||
#include "../src/utils/omp.h"
|
||||
#include "../src/utils/matrix_csr.h"
|
||||
|
||||
using namespace xgboost;
|
||||
// implements error handling
|
||||
@ -119,7 +119,7 @@ extern "C" {
|
||||
}
|
||||
}
|
||||
SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
||||
size_t olen;
|
||||
uint64_t olen;
|
||||
const float *res = XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
|
||||
CHAR(asChar(field)), &olen);
|
||||
SEXP ret = PROTECT(allocVector(REALSXP, olen));
|
||||
@ -188,7 +188,7 @@ extern "C" {
|
||||
&vec_dmats[0], &vec_sptr[0], len));
|
||||
}
|
||||
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin) {
|
||||
size_t olen;
|
||||
uint64_t olen;
|
||||
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
|
||||
R_ExternalPtrAddr(dmat),
|
||||
asInteger(output_margin),
|
||||
@ -207,13 +207,13 @@ extern "C" {
|
||||
XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
|
||||
}
|
||||
void XGBoosterDumpModel_R(SEXP handle, SEXP fname, SEXP fmap) {
|
||||
size_t olen;
|
||||
uint64_t olen;
|
||||
const char **res = XGBoosterDumpModel(R_ExternalPtrAddr(handle),
|
||||
CHAR(asChar(fmap)),
|
||||
&olen);
|
||||
FILE *fo = utils::FopenCheck(CHAR(asChar(fname)), "w");
|
||||
for (size_t i = 0; i < olen; ++i) {
|
||||
fprintf(fo, "booster[%u]:\n", static_cast<unsigned>(i));
|
||||
fprintf(fo, "booster[%lu]:\n", i);
|
||||
fprintf(fo, "%s", res[i]);
|
||||
}
|
||||
fclose(fo);
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include "utils/io.h"
|
||||
#include "utils/omp.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/iterator.h"
|
||||
#include "utils/random.h"
|
||||
@ -370,9 +371,9 @@ class FMatrixS : public FMatrixInterface<FMatrixS>{
|
||||
}
|
||||
|
||||
// sort columns
|
||||
unsigned ncol = static_cast<unsigned>(this->NumCol());
|
||||
bst_omp_uint ncol = static_cast<bst_omp_uint>(this->NumCol());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < ncol; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
std::sort(&col_data_[0] + col_ptr_[i],
|
||||
&col_data_[0] + col_ptr_[i + 1], Entry::CmpValue);
|
||||
}
|
||||
|
||||
@ -51,9 +51,9 @@ class GBLinear : public IGradBooster<FMatrix> {
|
||||
// for all the output group
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
const unsigned ndata = static_cast<unsigned>(rowset.size());
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||
#pragma omp parallel for schedule(static) reduction(+: sum_grad, sum_hess)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
bst_gpair &p = gpair[rowset[i] * ngroup + gid];
|
||||
if (p.hess >= 0.0f) {
|
||||
sum_grad += p.grad; sum_hess += p.hess;
|
||||
@ -65,7 +65,7 @@ class GBLinear : public IGradBooster<FMatrix> {
|
||||
model.bias()[gid] += dw;
|
||||
// update grad value
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
bst_gpair &p = gpair[rowset[i] * ngroup + gid];
|
||||
if (p.hess >= 0.0f) {
|
||||
p.grad += p.hess * dw;
|
||||
@ -73,9 +73,9 @@ class GBLinear : public IGradBooster<FMatrix> {
|
||||
}
|
||||
}
|
||||
// number of features
|
||||
const unsigned nfeat = static_cast<unsigned>(feat_index.size());
|
||||
const bst_omp_uint nfeat = static_cast<bst_omp_uint>(feat_index.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < nfeat; ++i) {
|
||||
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
||||
const bst_uint fid = feat_index[i];
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
@ -117,9 +117,9 @@ class GBLinear : public IGradBooster<FMatrix> {
|
||||
// k is number of group
|
||||
preds.resize(preds.size() + batch.size * ngroup);
|
||||
// parallel over local batch
|
||||
const unsigned nsize = static_cast<unsigned>(batch.size);
|
||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < nsize; ++i) {
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const size_t ridx = batch.base_rowid + i;
|
||||
// loop over output groups
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
|
||||
@ -94,8 +94,9 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
"must have exactly ngroup*nrow gpairs");
|
||||
std::vector<bst_gpair> tmp(gpair.size()/ngroup);
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
bst_omp_uint nsize = static_cast<bst_omp_uint>(tmp.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (size_t i = 0; i < tmp.size(); ++i) {
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
tmp[i] = gpair[i * ngroup + gid];
|
||||
}
|
||||
this->BoostNewTrees(tmp, fmat, info, gid);
|
||||
@ -129,9 +130,9 @@ class GBTree : public IGradBooster<FMatrix> {
|
||||
// k is number of group
|
||||
preds.resize(preds.size() + batch.size * mparam.num_output_group);
|
||||
// parallel over local batch
|
||||
const unsigned nsize = static_cast<unsigned>(batch.size);
|
||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < nsize; ++i) {
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const int tid = omp_get_thread_num();
|
||||
tree::RegTree::FVec &feats = thread_temp[tid];
|
||||
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||
|
||||
@ -26,10 +26,10 @@ struct EvalEWiseBase : public IEvaluator {
|
||||
const MetaInfo &info) const {
|
||||
utils::Check(preds.size() == info.labels.size(),
|
||||
"label and prediction size not match");
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
||||
float sum = 0.0, wsum = 0.0;
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const float wt = info.GetWeight(i);
|
||||
sum += Derived::EvalRow(info.labels[i], preds[i]) * wt;
|
||||
wsum += wt;
|
||||
@ -109,12 +109,12 @@ struct EvalAMS : public IEvaluator {
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
||||
utils::Check(info.weights.size() == ndata, "we need weight to evaluate ams");
|
||||
std::vector< std::pair<float, unsigned> > rec(ndata);
|
||||
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
rec[i] = std::make_pair(preds[i], i);
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), CmpFirst);
|
||||
@ -211,7 +211,7 @@ struct EvalAuc : public IEvaluator {
|
||||
const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
utils::Check(gptr.back() == preds.size(),
|
||||
"EvalAuc: group structure must match number of prediction");
|
||||
const unsigned ngroup = static_cast<unsigned>(gptr.size() - 1);
|
||||
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
// sum statictis
|
||||
double sum_auc = 0.0f;
|
||||
#pragma omp parallel reduction(+:sum_auc)
|
||||
@ -219,7 +219,7 @@ struct EvalAuc : public IEvaluator {
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<float, unsigned> > rec;
|
||||
#pragma omp for schedule(static)
|
||||
for (unsigned k = 0; k < ngroup; ++k) {
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.push_back(std::make_pair(preds[j], j));
|
||||
@ -269,7 +269,7 @@ struct EvalRankList : public IEvaluator {
|
||||
utils::Assert(gptr.size() != 0, "must specify group when constructing rank file");
|
||||
utils::Assert(gptr.back() == preds.size(),
|
||||
"EvalRanklist: group structure must match number of prediction");
|
||||
const unsigned ngroup = static_cast<unsigned>(gptr.size() - 1);
|
||||
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
// sum statistics
|
||||
double sum_metric = 0.0f;
|
||||
#pragma omp parallel reduction(+:sum_metric)
|
||||
@ -277,7 +277,7 @@ struct EvalRankList : public IEvaluator {
|
||||
// each thread takes a local rec
|
||||
std::vector< std::pair<float, unsigned> > rec;
|
||||
#pragma omp for schedule(static)
|
||||
for (unsigned k = 0; k < ngroup; ++k) {
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
rec.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
|
||||
rec.push_back(std::make_pair(preds[j], static_cast<int>(info.labels[j])));
|
||||
|
||||
@ -253,17 +253,17 @@ class BoostLearner {
|
||||
data.info.info, out_preds);
|
||||
// add base margin
|
||||
std::vector<float> &preds = *out_preds;
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
||||
if (data.info.base_margin.size() != 0) {
|
||||
utils::Check(preds.size() == data.info.base_margin.size(),
|
||||
"base_margin.size does not match with prediction size");
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
preds[j] += data.info.base_margin[j];
|
||||
}
|
||||
} else {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
preds[j] += mparam.base_score;
|
||||
}
|
||||
}
|
||||
|
||||
@ -116,9 +116,9 @@ class RegLossObj : public IObjFunction{
|
||||
gpair.resize(preds.size());
|
||||
// start calculating gradient
|
||||
const unsigned nstep = static_cast<unsigned>(info.labels.size());
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const unsigned j = i % nstep;
|
||||
float p = loss.PredTransform(preds[i]);
|
||||
float w = info.GetWeight(j);
|
||||
@ -132,9 +132,9 @@ class RegLossObj : public IObjFunction{
|
||||
}
|
||||
virtual void PredTransform(std::vector<float> *io_preds) {
|
||||
std::vector<float> &preds = *io_preds;
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
preds[j] = loss.PredTransform(preds[j]);
|
||||
}
|
||||
}
|
||||
@ -169,12 +169,12 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
std::vector<bst_gpair> &gpair = *out_gpair;
|
||||
gpair.resize(preds.size());
|
||||
const unsigned nstep = static_cast<unsigned>(info.labels.size() * nclass);
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size() / nclass);
|
||||
const unsigned ndata = static_cast<bst_omp_uint>(preds.size() / nclass);
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<float> rec(nclass);
|
||||
#pragma omp for schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
rec[k] = preds[i * nclass + k];
|
||||
}
|
||||
@ -210,13 +210,13 @@ class SoftmaxMultiClassObj : public IObjFunction {
|
||||
utils::Check(nclass != 0, "must set num_class to use softmax");
|
||||
std::vector<float> &preds = *io_preds;
|
||||
std::vector<float> tmp;
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size()/nclass);
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size()/nclass);
|
||||
if (prob == 0) tmp.resize(ndata);
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<float> rec(nclass);
|
||||
#pragma omp for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
rec[k] = preds[j * nclass + k];
|
||||
}
|
||||
@ -263,7 +263,7 @@ class LambdaRankObj : public IObjFunction {
|
||||
const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
|
||||
utils::Check(gptr.size() != 0 && gptr.back() == info.labels.size(),
|
||||
"group structure not consistent with #rows");
|
||||
const unsigned ngroup = static_cast<unsigned>(gptr.size() - 1);
|
||||
const bst_omp_uint ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
|
||||
#pragma omp parallel
|
||||
{
|
||||
// parall construct, declare random number generator here, so that each
|
||||
@ -273,7 +273,7 @@ class LambdaRankObj : public IObjFunction {
|
||||
std::vector<ListEntry> lst;
|
||||
std::vector< std::pair<float, unsigned> > rec;
|
||||
#pragma omp for schedule(static)
|
||||
for (unsigned k = 0; k < ngroup; ++k) {
|
||||
for (bst_omp_uint k = 0; k < ngroup; ++k) {
|
||||
lst.clear(); pairs.clear();
|
||||
for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) {
|
||||
lst.push_back(ListEntry(preds[j], info.labels[j], j));
|
||||
|
||||
@ -186,9 +186,9 @@ class ColMaker: public IUpdater<FMatrix> {
|
||||
}
|
||||
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
|
||||
// setup position
|
||||
const unsigned ndata = static_cast<unsigned>(rowset.size());
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
const int tid = omp_get_thread_num();
|
||||
if (position[ridx] < 0) continue;
|
||||
@ -286,12 +286,12 @@ class ColMaker: public IUpdater<FMatrix> {
|
||||
feat_set.resize(n);
|
||||
}
|
||||
// start enumeration
|
||||
const unsigned nsize = static_cast<unsigned>(feat_set.size());
|
||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(feat_set.size());
|
||||
#if defined(_OPENMP)
|
||||
const int batch_size = std::max(static_cast<int>(nsize / this->nthread / 32), 1);
|
||||
#endif
|
||||
#pragma omp parallel for schedule(dynamic, batch_size)
|
||||
for (unsigned i = 0; i < nsize; ++i) {
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
const unsigned fid = feat_set[i];
|
||||
const int tid = omp_get_thread_num();
|
||||
if (param.need_forward_search(fmat.GetColDensity(fid))) {
|
||||
@ -321,9 +321,9 @@ class ColMaker: public IUpdater<FMatrix> {
|
||||
inline void ResetPosition(const std::vector<int> &qexpand, const FMatrix &fmat, const RegTree &tree) {
|
||||
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
|
||||
// step 1, set default direct nodes to default, and leaf nodes to -1
|
||||
const unsigned ndata = static_cast<unsigned>(rowset.size());
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
const bst_uint ridx = rowset[i];
|
||||
const int nid = position[ridx];
|
||||
if (nid >= 0) {
|
||||
@ -344,9 +344,9 @@ class ColMaker: public IUpdater<FMatrix> {
|
||||
std::sort(fsplits.begin(), fsplits.end());
|
||||
fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin());
|
||||
// start put things into right place
|
||||
const unsigned nfeats = static_cast<unsigned>(fsplits.size());
|
||||
const bst_omp_uint nfeats = static_cast<bst_omp_uint>(fsplits.size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (unsigned i = 0; i < nfeats; ++i) {
|
||||
for (bst_omp_uint i = 0; i < nfeats; ++i) {
|
||||
const unsigned fid = fsplits[i];
|
||||
for (typename FMatrix::ColIter it = fmat.GetSortedCol(fid); it.Next();) {
|
||||
const bst_uint ridx = it.rindex();
|
||||
|
||||
@ -56,9 +56,9 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
||||
const SparseBatch &batch = iter->Value();
|
||||
utils::Check(batch.size < std::numeric_limits<unsigned>::max(),
|
||||
"too large batch size ");
|
||||
const unsigned nbatch = static_cast<unsigned>(batch.size);
|
||||
const bst_omp_uint nbatch = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned i = 0; i < nbatch; ++i) {
|
||||
for (bst_omp_uint i = 0; i < nbatch; ++i) {
|
||||
SparseBatch::Inst inst = batch[i];
|
||||
const int tid = omp_get_thread_num();
|
||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
|
||||
@ -7,6 +7,15 @@
|
||||
*/
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
namespace xgboost {
|
||||
// loop variable used in openmp
|
||||
#ifdef _MSC_VER
|
||||
typedef int bst_omp_uint;
|
||||
#else
|
||||
typedef unsigned bst_omp_uint;
|
||||
#endif
|
||||
} // namespace xgboost
|
||||
|
||||
#else
|
||||
#ifndef DISABLE_OPENMP
|
||||
#ifndef _MSC_VER
|
||||
|
||||
@ -213,7 +213,7 @@ extern "C" {
|
||||
&olen);
|
||||
FILE *fo = utils::FopenCheck(CHAR(asChar(fname)), "w");
|
||||
for (size_t i = 0; i < olen; ++i) {
|
||||
fprintf(fo, "booster[%lu]:\n", i);
|
||||
fprintf(fo, "booster[%u]:\n", static_cast<unsigned>(i));
|
||||
fprintf(fo, "%s", res[i]);
|
||||
}
|
||||
fclose(fo);
|
||||
|
||||
@ -32,9 +32,9 @@ class Booster: public learner::BoostLearner<FMatrixS> {
|
||||
inline void BoostOneIter(const DataMatrix &train,
|
||||
float *grad, float *hess, uint64_t len) {
|
||||
this->gpair_.resize(len);
|
||||
const unsigned ndata = static_cast<unsigned>(len);
|
||||
const bst_omp_uint ndata = static_cast<bst_omp_uint>(len);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (unsigned j = 0; j < ndata; ++j) {
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
gpair_[j] = bst_gpair(grad[j], hess[j]);
|
||||
}
|
||||
gbm_->DoBoost(train.fmat, train.info.info, &gpair_);
|
||||
|
||||
@ -9,13 +9,14 @@
|
||||
#include <cstdio>
|
||||
// define uint64_t to be unsigned long
|
||||
typedef unsigned long uint64_t;
|
||||
#define XGB_DLL
|
||||
|
||||
extern "C" {
|
||||
/*!
|
||||
* \brief load a data matrix
|
||||
* \return a loaded data matrix
|
||||
*/
|
||||
void* XGDMatrixCreateFromFile(const char *fname, int silent);
|
||||
XGB_DLL void* XGDMatrixCreateFromFile(const char *fname, int silent);
|
||||
/*!
|
||||
* \brief create a matrix content from csr format
|
||||
* \param indptr pointer to row headers
|
||||
@ -25,7 +26,7 @@ extern "C" {
|
||||
* \param nelem number of nonzero elements in the matrix
|
||||
* \return created dmatrix
|
||||
*/
|
||||
void* XGDMatrixCreateFromCSR(const uint64_t *indptr,
|
||||
XGB_DLL void* XGDMatrixCreateFromCSR(const uint64_t *indptr,
|
||||
const unsigned *indices,
|
||||
const float *data,
|
||||
uint64_t nindptr,
|
||||
@ -38,7 +39,7 @@ extern "C" {
|
||||
* \param missing which value to represent missing value
|
||||
* \return created dmatrix
|
||||
*/
|
||||
void* XGDMatrixCreateFromMat(const float *data,
|
||||
XGB_DLL void* XGDMatrixCreateFromMat(const float *data,
|
||||
uint64_t nrow,
|
||||
uint64_t ncol,
|
||||
float missing);
|
||||
@ -49,20 +50,20 @@ extern "C" {
|
||||
* \param len length of index set
|
||||
* \return a sliced new matrix
|
||||
*/
|
||||
void* XGDMatrixSliceDMatrix(void *handle,
|
||||
XGB_DLL void* XGDMatrixSliceDMatrix(void *handle,
|
||||
const int *idxset,
|
||||
uint64_t len);
|
||||
/*!
|
||||
* \brief free space in data matrix
|
||||
*/
|
||||
void XGDMatrixFree(void *handle);
|
||||
XGB_DLL void XGDMatrixFree(void *handle);
|
||||
/*!
|
||||
* \brief load a data matrix into binary file
|
||||
* \param handle a instance of data matrix
|
||||
* \param fname file name
|
||||
* \param silent print statistics when saving
|
||||
*/
|
||||
void XGDMatrixSaveBinary(void *handle, const char *fname, int silent);
|
||||
XGB_DLL void XGDMatrixSaveBinary(void *handle, const char *fname, int silent);
|
||||
/*!
|
||||
* \brief set float vector to a content in info
|
||||
* \param handle a instance of data matrix
|
||||
@ -70,7 +71,7 @@ extern "C" {
|
||||
* \param array pointer to float vector
|
||||
* \param len length of array
|
||||
*/
|
||||
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *array, uint64_t len);
|
||||
XGB_DLL void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *array, uint64_t len);
|
||||
/*!
|
||||
* \brief set uint32 vector to a content in info
|
||||
* \param handle a instance of data matrix
|
||||
@ -78,14 +79,14 @@ extern "C" {
|
||||
* \param array pointer to float vector
|
||||
* \param len length of array
|
||||
*/
|
||||
void XGDMatrixSetUIntInfo(void *handle, const char *field, const unsigned *array, uint64_t len);
|
||||
XGB_DLL void XGDMatrixSetUIntInfo(void *handle, const char *field, const unsigned *array, uint64_t len);
|
||||
/*!
|
||||
* \brief set label of the training matrix
|
||||
* \param handle a instance of data matrix
|
||||
* \param group pointer to group size
|
||||
* \param len length of array
|
||||
*/
|
||||
void XGDMatrixSetGroup(void *handle, const unsigned *group, uint64_t len);
|
||||
XGB_DLL void XGDMatrixSetGroup(void *handle, const unsigned *group, uint64_t len);
|
||||
/*!
|
||||
* \brief get float info vector from matrix
|
||||
* \param handle a instance of data matrix
|
||||
@ -93,7 +94,7 @@ extern "C" {
|
||||
* \param out_len used to set result length
|
||||
* \return pointer to the result
|
||||
*/
|
||||
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, uint64_t* out_len);
|
||||
XGB_DLL const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, uint64_t* out_len);
|
||||
/*!
|
||||
* \brief get uint32 info vector from matrix
|
||||
* \param handle a instance of data matrix
|
||||
@ -101,37 +102,37 @@ extern "C" {
|
||||
* \param out_len used to set result length
|
||||
* \return pointer to the result
|
||||
*/
|
||||
const unsigned* XGDMatrixGetUIntInfo(const void *handle, const char *field, uint64_t* out_len);
|
||||
XGB_DLL const unsigned* XGDMatrixGetUIntInfo(const void *handle, const char *field, uint64_t* out_len);
|
||||
/*!
|
||||
* \brief return number of rows
|
||||
*/
|
||||
uint64_t XGDMatrixNumRow(const void *handle);
|
||||
XGB_DLL uint64_t XGDMatrixNumRow(const void *handle);
|
||||
// --- start XGBoost class
|
||||
/*!
|
||||
* \brief create xgboost learner
|
||||
* \param dmats matrices that are set to be cached
|
||||
* \param len length of dmats
|
||||
*/
|
||||
void *XGBoosterCreate(void* dmats[], uint64_t len);
|
||||
XGB_DLL void *XGBoosterCreate(void* dmats[], uint64_t len);
|
||||
/*!
|
||||
* \brief free obj in handle
|
||||
* \param handle handle to be freed
|
||||
*/
|
||||
void XGBoosterFree(void* handle);
|
||||
XGB_DLL void XGBoosterFree(void* handle);
|
||||
/*!
|
||||
* \brief set parameters
|
||||
* \param handle handle
|
||||
* \param name parameter name
|
||||
* \param val value of parameter
|
||||
*/
|
||||
void XGBoosterSetParam(void *handle, const char *name, const char *value);
|
||||
XGB_DLL void XGBoosterSetParam(void *handle, const char *name, const char *value);
|
||||
/*!
|
||||
* \brief update the model in one round using dtrain
|
||||
* \param handle handle
|
||||
* \param iter current iteration rounds
|
||||
* \param dtrain training data
|
||||
*/
|
||||
void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain);
|
||||
XGB_DLL void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain);
|
||||
/*!
|
||||
* \brief update the model, by directly specify gradient and second order gradient,
|
||||
* this can be used to replace UpdateOneIter, to support customized loss function
|
||||
@ -141,7 +142,7 @@ extern "C" {
|
||||
* \param hess second order gradient statistics
|
||||
* \param len length of grad/hess array
|
||||
*/
|
||||
void XGBoosterBoostOneIter(void *handle, void *dtrain,
|
||||
XGB_DLL void XGBoosterBoostOneIter(void *handle, void *dtrain,
|
||||
float *grad, float *hess, uint64_t len);
|
||||
/*!
|
||||
* \brief get evaluation statistics for xgboost
|
||||
@ -152,7 +153,7 @@ extern "C" {
|
||||
* \param len length of dmats
|
||||
* \return the string containing evaluation stati
|
||||
*/
|
||||
const char *XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
|
||||
XGB_DLL const char *XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
|
||||
const char *evnames[], uint64_t len);
|
||||
/*!
|
||||
* \brief make prediction based on dmat
|
||||
@ -161,19 +162,19 @@ extern "C" {
|
||||
* \param output_margin whether only output raw margin value
|
||||
* \param len used to store length of returning result
|
||||
*/
|
||||
const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, uint64_t *len);
|
||||
XGB_DLL const float *XGBoosterPredict(void *handle, void *dmat, int output_margin, uint64_t *len);
|
||||
/*!
|
||||
* \brief load model from existing file
|
||||
* \param handle handle
|
||||
* \param fname file name
|
||||
*/
|
||||
void XGBoosterLoadModel(void *handle, const char *fname);
|
||||
XGB_DLL void XGBoosterLoadModel(void *handle, const char *fname);
|
||||
/*!
|
||||
* \brief save model into existing file
|
||||
* \param handle handle
|
||||
* \param fname file name
|
||||
*/
|
||||
void XGBoosterSaveModel(const void *handle, const char *fname);
|
||||
XGB_DLL void XGBoosterSaveModel(const void *handle, const char *fname);
|
||||
/*!
|
||||
* \brief dump model, return array of strings representing model dump
|
||||
* \param handle handle
|
||||
@ -181,7 +182,7 @@ extern "C" {
|
||||
* \param out_len length of output array
|
||||
* \return char *data[], representing dump of each model
|
||||
*/
|
||||
const char **XGBoosterDumpModel(void *handle, const char *fmap,
|
||||
XGB_DLL const char **XGBoosterDumpModel(void *handle, const char *fmap,
|
||||
uint64_t *out_len);
|
||||
};
|
||||
#endif // XGBOOST_WRAPPER_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user