Merge pull request #42 from tqchen/unity
Unity this is final minor change in data structure
This commit is contained in:
commit
4c023077dd
@ -42,11 +42,17 @@ class TreeModel {
|
|||||||
int max_depth;
|
int max_depth;
|
||||||
/*! \brief number of features used for tree construction */
|
/*! \brief number of features used for tree construction */
|
||||||
int num_feature;
|
int num_feature;
|
||||||
|
/*!
|
||||||
|
* \brief leaf vector size, used for vector tree
|
||||||
|
* used to store more than one dimensional information in tree
|
||||||
|
*/
|
||||||
|
int size_leaf_vector;
|
||||||
/*! \brief reserved part */
|
/*! \brief reserved part */
|
||||||
int reserved[32];
|
int reserved[31];
|
||||||
/*! \brief constructor */
|
/*! \brief constructor */
|
||||||
Param(void) {
|
Param(void) {
|
||||||
max_depth = 0;
|
max_depth = 0;
|
||||||
|
size_leaf_vector = 0;
|
||||||
memset(reserved, 0, sizeof(reserved));
|
memset(reserved, 0, sizeof(reserved));
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -57,6 +63,7 @@ class TreeModel {
|
|||||||
inline void SetParam(const char *name, const char *val) {
|
inline void SetParam(const char *name, const char *val) {
|
||||||
if (!strcmp("num_roots", name)) num_roots = atoi(val);
|
if (!strcmp("num_roots", name)) num_roots = atoi(val);
|
||||||
if (!strcmp("num_feature", name)) num_feature = atoi(val);
|
if (!strcmp("num_feature", name)) num_feature = atoi(val);
|
||||||
|
if (!strcmp("size_leaf_vector", name)) size_leaf_vector = atoi(val);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
/*! \brief tree node */
|
/*! \brief tree node */
|
||||||
@ -166,10 +173,12 @@ class TreeModel {
|
|||||||
protected:
|
protected:
|
||||||
// vector of nodes
|
// vector of nodes
|
||||||
std::vector<Node> nodes;
|
std::vector<Node> nodes;
|
||||||
// stats of nodes
|
|
||||||
std::vector<TNodeStat> stats;
|
|
||||||
// free node space, used during training process
|
// free node space, used during training process
|
||||||
std::vector<int> deleted_nodes;
|
std::vector<int> deleted_nodes;
|
||||||
|
// stats of nodes
|
||||||
|
std::vector<TNodeStat> stats;
|
||||||
|
// leaf vector, that is used to store additional information
|
||||||
|
std::vector<bst_float> leaf_vector;
|
||||||
// allocate a new node,
|
// allocate a new node,
|
||||||
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
// !!!!!! NOTE: may cause BUG here, nodes.resize
|
||||||
inline int AllocNode(void) {
|
inline int AllocNode(void) {
|
||||||
@ -184,6 +193,7 @@ class TreeModel {
|
|||||||
"number of nodes in the tree exceed 2^31");
|
"number of nodes in the tree exceed 2^31");
|
||||||
nodes.resize(param.num_nodes);
|
nodes.resize(param.num_nodes);
|
||||||
stats.resize(param.num_nodes);
|
stats.resize(param.num_nodes);
|
||||||
|
leaf_vector.resize(param.num_nodes * param.size_leaf_vector);
|
||||||
return nd;
|
return nd;
|
||||||
}
|
}
|
||||||
// delete a tree node
|
// delete a tree node
|
||||||
@ -247,6 +257,14 @@ class TreeModel {
|
|||||||
inline NodeStat &stat(int nid) {
|
inline NodeStat &stat(int nid) {
|
||||||
return stats[nid];
|
return stats[nid];
|
||||||
}
|
}
|
||||||
|
/*! \brief get leaf vector given nid */
|
||||||
|
inline bst_float* leafvec(int nid) {
|
||||||
|
return &leaf_vector[nid * param.size_leaf_vector];
|
||||||
|
}
|
||||||
|
/*! \brief get leaf vector given nid */
|
||||||
|
inline const bst_float* leafvec(int nid) const{
|
||||||
|
return &leaf_vector[nid * param.size_leaf_vector];
|
||||||
|
}
|
||||||
/*! \brief initialize the model */
|
/*! \brief initialize the model */
|
||||||
inline void InitModel(void) {
|
inline void InitModel(void) {
|
||||||
param.num_nodes = param.num_roots;
|
param.num_nodes = param.num_roots;
|
||||||
|
|||||||
@ -145,8 +145,8 @@ struct GradStats {
|
|||||||
double sum_grad;
|
double sum_grad;
|
||||||
/*! \brief sum hessian statistics */
|
/*! \brief sum hessian statistics */
|
||||||
double sum_hess;
|
double sum_hess;
|
||||||
/*! \brief constructor */
|
/*! \brief constructor, the object must be cleared during construction */
|
||||||
GradStats(void) {
|
explicit GradStats(const TrainParam ¶m) {
|
||||||
this->Clear();
|
this->Clear();
|
||||||
}
|
}
|
||||||
/*! \brief clear the statistics */
|
/*! \brief clear the statistics */
|
||||||
@ -169,29 +169,31 @@ struct GradStats {
|
|||||||
inline double CalcWeight(const TrainParam ¶m) const {
|
inline double CalcWeight(const TrainParam ¶m) const {
|
||||||
return param.CalcWeight(sum_grad, sum_hess);
|
return param.CalcWeight(sum_grad, sum_hess);
|
||||||
}
|
}
|
||||||
/*!\brief calculate gain of the solution */
|
/*! \brief calculate gain of the solution */
|
||||||
inline double CalcGain(const TrainParam ¶m) const {
|
inline double CalcGain(const TrainParam ¶m) const {
|
||||||
return param.CalcGain(sum_grad, sum_hess);
|
return param.CalcGain(sum_grad, sum_hess);
|
||||||
}
|
}
|
||||||
/*! \brief add statistics to the data */
|
/*! \brief add statistics to the data */
|
||||||
inline void Add(double grad, double hess) {
|
|
||||||
sum_grad += grad; sum_hess += hess;
|
|
||||||
}
|
|
||||||
/*! \brief add statistics to the data */
|
|
||||||
inline void Add(const GradStats &b) {
|
inline void Add(const GradStats &b) {
|
||||||
this->Add(b.sum_grad, b.sum_hess);
|
this->Add(b.sum_grad, b.sum_hess);
|
||||||
}
|
}
|
||||||
/*! \brief substract the statistics by b */
|
/*! \brief set current value to a - b */
|
||||||
inline GradStats Substract(const GradStats &b) const {
|
inline void SetSubstract(const GradStats &a, const GradStats &b) {
|
||||||
GradStats res;
|
sum_grad = a.sum_grad - b.sum_grad;
|
||||||
res.sum_grad = this->sum_grad - b.sum_grad;
|
sum_hess = a.sum_hess - b.sum_hess;
|
||||||
res.sum_hess = this->sum_hess - b.sum_hess;
|
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
/*! \return whether the statistics is not used yet */
|
/*! \return whether the statistics is not used yet */
|
||||||
inline bool Empty(void) const {
|
inline bool Empty(void) const {
|
||||||
return sum_hess == 0.0;
|
return sum_hess == 0.0;
|
||||||
}
|
}
|
||||||
|
/*! \brief set leaf vector value based on statistics */
|
||||||
|
inline void SetLeafVec(const TrainParam ¶m, bst_float *vec) const{
|
||||||
|
}
|
||||||
|
protected:
|
||||||
|
/*! \brief add statistics to the data */
|
||||||
|
inline void Add(double grad, double hess) {
|
||||||
|
sum_grad += grad; sum_hess += hess;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -51,8 +51,8 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
/*! \brief current best solution */
|
/*! \brief current best solution */
|
||||||
SplitEntry best;
|
SplitEntry best;
|
||||||
// constructor
|
// constructor
|
||||||
ThreadEntry(void) {
|
explicit ThreadEntry(const TrainParam ¶m)
|
||||||
stats.Clear();
|
: stats(param) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
struct NodeEntry {
|
struct NodeEntry {
|
||||||
@ -65,8 +65,8 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
/*! \brief current best solution */
|
/*! \brief current best solution */
|
||||||
SplitEntry best;
|
SplitEntry best;
|
||||||
// constructor
|
// constructor
|
||||||
NodeEntry(void) : root_gain(0.0f), weight(0.0f){
|
explicit NodeEntry(const TrainParam ¶m)
|
||||||
stats.Clear();
|
: stats(param), root_gain(0.0f), weight(0.0f){
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// actual builder that runs the algorithm
|
// actual builder that runs the algorithm
|
||||||
@ -100,6 +100,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
p_tree->stat(nid).loss_chg = snode[nid].best.loss_chg;
|
p_tree->stat(nid).loss_chg = snode[nid].best.loss_chg;
|
||||||
p_tree->stat(nid).base_weight = snode[nid].weight;
|
p_tree->stat(nid).base_weight = snode[nid].weight;
|
||||||
p_tree->stat(nid).sum_hess = static_cast<float>(snode[nid].stats.sum_hess);
|
p_tree->stat(nid).sum_hess = static_cast<float>(snode[nid].stats.sum_hess);
|
||||||
|
snode[nid].stats.SetLeafVec(param, p_tree->leafvec(nid));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -179,9 +180,9 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
const RegTree &tree) {
|
const RegTree &tree) {
|
||||||
{// setup statistics space for each tree node
|
{// setup statistics space for each tree node
|
||||||
for (size_t i = 0; i < stemp.size(); ++i) {
|
for (size_t i = 0; i < stemp.size(); ++i) {
|
||||||
stemp[i].resize(tree.param.num_nodes, ThreadEntry());
|
stemp[i].resize(tree.param.num_nodes, ThreadEntry(param));
|
||||||
}
|
}
|
||||||
snode.resize(tree.param.num_nodes, NodeEntry());
|
snode.resize(tree.param.num_nodes, NodeEntry(param));
|
||||||
}
|
}
|
||||||
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
|
const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
|
||||||
// setup position
|
// setup position
|
||||||
@ -196,7 +197,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
// sum the per thread statistics together
|
// sum the per thread statistics together
|
||||||
for (size_t j = 0; j < qexpand.size(); ++j) {
|
for (size_t j = 0; j < qexpand.size(); ++j) {
|
||||||
const int nid = qexpand[j];
|
const int nid = qexpand[j];
|
||||||
TStats stats; stats.Clear();
|
TStats stats(param);
|
||||||
for (size_t tid = 0; tid < stemp.size(); ++tid) {
|
for (size_t tid = 0; tid < stemp.size(); ++tid) {
|
||||||
stats.Add(stemp[tid][nid].stats);
|
stats.Add(stemp[tid][nid].stats);
|
||||||
}
|
}
|
||||||
@ -231,6 +232,8 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
for (size_t j = 0; j < qexpand.size(); ++j) {
|
for (size_t j = 0; j < qexpand.size(); ++j) {
|
||||||
temp[qexpand[j]].stats.Clear();
|
temp[qexpand[j]].stats.Clear();
|
||||||
}
|
}
|
||||||
|
// left statistics
|
||||||
|
TStats c(param);
|
||||||
while (it.Next()) {
|
while (it.Next()) {
|
||||||
const bst_uint ridx = it.rindex();
|
const bst_uint ridx = it.rindex();
|
||||||
const int nid = position[ridx];
|
const int nid = position[ridx];
|
||||||
@ -246,7 +249,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
} else {
|
} else {
|
||||||
// try to find a split
|
// try to find a split
|
||||||
if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
|
if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
|
||||||
TStats c = snode[nid].stats.Substract(e.stats);
|
c.SetSubstract(snode[nid].stats, e.stats);
|
||||||
if (c.sum_hess >= param.min_child_weight) {
|
if (c.sum_hess >= param.min_child_weight) {
|
||||||
double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
|
double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
|
||||||
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
|
e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
|
||||||
@ -261,7 +264,7 @@ class ColMaker: public IUpdater<FMatrix> {
|
|||||||
for (size_t i = 0; i < qexpand.size(); ++i) {
|
for (size_t i = 0; i < qexpand.size(); ++i) {
|
||||||
const int nid = qexpand[i];
|
const int nid = qexpand[i];
|
||||||
ThreadEntry &e = temp[nid];
|
ThreadEntry &e = temp[nid];
|
||||||
TStats c = snode[nid].stats.Substract(e.stats);
|
c.SetSubstract(snode[nid].stats, e.stats);
|
||||||
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
|
if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
|
||||||
const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
|
const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
|
||||||
const float delta = is_forward_search ? rt_eps : -rt_eps;
|
const float delta = is_forward_search ? rt_eps : -rt_eps;
|
||||||
|
|||||||
@ -44,8 +44,8 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
int tid = omp_get_thread_num();
|
int tid = omp_get_thread_num();
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
std::vector<TStats> &vec = stemp[tid * trees.size() + i];
|
std::vector<TStats> &vec = stemp[tid * trees.size() + i];
|
||||||
vec.resize(trees[i]->param.num_nodes);
|
vec.resize(trees[i]->param.num_nodes, TStats(param));
|
||||||
std::fill(vec.begin(), vec.end(), TStats());
|
std::fill(vec.begin(), vec.end(), TStats(param));
|
||||||
}
|
}
|
||||||
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
||||||
}
|
}
|
||||||
@ -114,6 +114,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
|
|||||||
RegTree &tree = *p_tree;
|
RegTree &tree = *p_tree;
|
||||||
tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
|
tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
|
||||||
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
|
tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
|
||||||
|
gstats[nid].SetLeafVec(param, tree.leafvec(nid));
|
||||||
if (tree[nid].is_leaf()) {
|
if (tree[nid].is_leaf()) {
|
||||||
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@ -19,6 +19,7 @@ xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
|
|||||||
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
|
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
|
||||||
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
|
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
|
||||||
xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
|
xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
|
||||||
|
xglib.XGDMatrixGetUIntInfo.restype = ctypes.POINTER(ctypes.c_uint)
|
||||||
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
|
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
|
||||||
|
|
||||||
xglib.XGBoosterCreate.restype = ctypes.c_void_p
|
xglib.XGBoosterCreate.restype = ctypes.c_void_p
|
||||||
@ -27,10 +28,10 @@ xglib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
|
|||||||
xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
|
xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
|
||||||
|
|
||||||
|
|
||||||
def ctypes2numpy(cptr, length):
|
def ctypes2numpy(cptr, length, dtype):
|
||||||
# convert a ctypes pointer array to numpy
|
# convert a ctypes pointer array to numpy
|
||||||
assert isinstance(cptr, ctypes.POINTER(ctypes.c_float))
|
assert isinstance(cptr, ctypes.POINTER(ctypes.c_float))
|
||||||
res = numpy.zeros(length, dtype='float32')
|
res = numpy.zeros(length, dtype=dtype)
|
||||||
assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0])
|
assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0])
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -76,23 +77,31 @@ class DMatrix:
|
|||||||
# destructor
|
# destructor
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
xglib.XGDMatrixFree(self.handle)
|
xglib.XGDMatrixFree(self.handle)
|
||||||
def __get_float_info(self, field):
|
def get_float_info(self, field):
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
||||||
ctypes.byref(length))
|
ctypes.byref(length))
|
||||||
return ctypes2numpy(ret, length.value)
|
return ctypes2numpy(ret, length.value, 'float32')
|
||||||
def __set_float_info(self, field, data):
|
def get_uint_info(self, field):
|
||||||
xglib.XGDMatrixSetFloatInfo(self.handle,ctypes.c_char_p(field.encode('utf-8')),
|
length = ctypes.c_ulong()
|
||||||
|
ret = xglib.XGDMatrixGetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
||||||
|
ctypes.byref(length))
|
||||||
|
return ctypes2numpy(ret, length.value, 'uint32')
|
||||||
|
def set_float_info(self, field, data):
|
||||||
|
xglib.XGDMatrixSetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
||||||
(ctypes.c_float*len(data))(*data), len(data))
|
(ctypes.c_float*len(data))(*data), len(data))
|
||||||
|
def set_uint_info(self, field, data):
|
||||||
|
xglib.XGDMatrixSetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
|
||||||
|
(ctypes.c_uint*len(data))(*data), len(data))
|
||||||
# load data from file
|
# load data from file
|
||||||
def save_binary(self, fname, silent=True):
|
def save_binary(self, fname, silent=True):
|
||||||
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
|
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
|
||||||
# set label of dmatrix
|
# set label of dmatrix
|
||||||
def set_label(self, label):
|
def set_label(self, label):
|
||||||
self.__set_float_info('label', label)
|
self.set_float_info('label', label)
|
||||||
# set weight of each instances
|
# set weight of each instances
|
||||||
def set_weight(self, weight):
|
def set_weight(self, weight):
|
||||||
self.__set_float_info('weight', weight)
|
self.set_float_info('weight', weight)
|
||||||
# set initialized margin prediction
|
# set initialized margin prediction
|
||||||
def set_base_margin(self, margin):
|
def set_base_margin(self, margin):
|
||||||
"""
|
"""
|
||||||
@ -103,19 +112,19 @@ class DMatrix:
|
|||||||
e.g. for logistic regression: need to put in value before logistic transformation
|
e.g. for logistic regression: need to put in value before logistic transformation
|
||||||
see also example/demo.py
|
see also example/demo.py
|
||||||
"""
|
"""
|
||||||
self.__set_float_info('base_margin', margin)
|
self.set_float_info('base_margin', margin)
|
||||||
# set group size of dmatrix, used for rank
|
# set group size of dmatrix, used for rank
|
||||||
def set_group(self, group):
|
def set_group(self, group):
|
||||||
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group))
|
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group))
|
||||||
# get label from dmatrix
|
# get label from dmatrix
|
||||||
def get_label(self):
|
def get_label(self):
|
||||||
return self.__get_float_info('label')
|
return self.get_float_info('label')
|
||||||
# get weight from dmatrix
|
# get weight from dmatrix
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
return self.__get_float_info('weight')
|
return self.get_float_info('weight')
|
||||||
# get base_margin from dmatrix
|
# get base_margin from dmatrix
|
||||||
def get_base_margin(self):
|
def get_base_margin(self):
|
||||||
return self.__get_float_info('base_margin')
|
return self.get_float_info('base_margin')
|
||||||
def num_row(self):
|
def num_row(self):
|
||||||
return xglib.XGDMatrixNumRow(self.handle)
|
return xglib.XGDMatrixNumRow(self.handle)
|
||||||
# slice the DMatrix to return a new DMatrix that only contains rindex
|
# slice the DMatrix to return a new DMatrix that only contains rindex
|
||||||
@ -189,7 +198,7 @@ class Booster:
|
|||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
preds = xglib.XGBoosterPredict(self.handle, data.handle,
|
||||||
int(output_margin), ctypes.byref(length))
|
int(output_margin), ctypes.byref(length))
|
||||||
return ctypes2numpy(preds, length.value)
|
return ctypes2numpy(preds, length.value, 'float32')
|
||||||
def save_model(self, fname):
|
def save_model(self, fname):
|
||||||
""" save model to file """
|
""" save model to file """
|
||||||
xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8')))
|
xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8')))
|
||||||
|
|||||||
@ -88,10 +88,10 @@ extern "C"{
|
|||||||
mat.row_data_.resize(nelem);
|
mat.row_data_.resize(nelem);
|
||||||
for (size_t i = 0; i < nelem; ++i) {
|
for (size_t i = 0; i < nelem; ++i) {
|
||||||
mat.row_data_[i] = SparseBatch::Entry(indices[i], data[i]);
|
mat.row_data_[i] = SparseBatch::Entry(indices[i], data[i]);
|
||||||
mat.info.num_col = std::max(mat.info.num_col,
|
mat.info.info.num_col = std::max(mat.info.info.num_col,
|
||||||
static_cast<size_t>(indices[i]+1));
|
static_cast<size_t>(indices[i]+1));
|
||||||
}
|
}
|
||||||
mat.info.num_row = nindptr - 1;
|
mat.info.info.num_row = nindptr - 1;
|
||||||
return p_mat;
|
return p_mat;
|
||||||
}
|
}
|
||||||
void* XGDMatrixCreateFromMat(const float *data,
|
void* XGDMatrixCreateFromMat(const float *data,
|
||||||
@ -100,8 +100,8 @@ extern "C"{
|
|||||||
float missing) {
|
float missing) {
|
||||||
DMatrixSimple *p_mat = new DMatrixSimple();
|
DMatrixSimple *p_mat = new DMatrixSimple();
|
||||||
DMatrixSimple &mat = *p_mat;
|
DMatrixSimple &mat = *p_mat;
|
||||||
mat.info.num_row = nrow;
|
mat.info.info.num_row = nrow;
|
||||||
mat.info.num_col = ncol;
|
mat.info.info.num_col = ncol;
|
||||||
for (size_t i = 0; i < nrow; ++i, data += ncol) {
|
for (size_t i = 0; i < nrow; ++i, data += ncol) {
|
||||||
size_t nelem = 0;
|
size_t nelem = 0;
|
||||||
for (size_t j = 0; j < ncol; ++j) {
|
for (size_t j = 0; j < ncol; ++j) {
|
||||||
@ -130,8 +130,8 @@ extern "C"{
|
|||||||
utils::Check(src.info.group_ptr.size() == 0,
|
utils::Check(src.info.group_ptr.size() == 0,
|
||||||
"slice does not support group structure");
|
"slice does not support group structure");
|
||||||
ret.Clear();
|
ret.Clear();
|
||||||
ret.info.num_row = len;
|
ret.info.info.num_row = len;
|
||||||
ret.info.num_col = src.info.num_col;
|
ret.info.info.num_col = src.info.num_col();
|
||||||
|
|
||||||
utils::IIterator<SparseBatch> *iter = src.fmat.RowIterator();
|
utils::IIterator<SparseBatch> *iter = src.fmat.RowIterator();
|
||||||
iter->BeforeFirst();
|
iter->BeforeFirst();
|
||||||
@ -165,10 +165,16 @@ extern "C"{
|
|||||||
}
|
}
|
||||||
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *info, size_t len) {
|
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *info, size_t len) {
|
||||||
std::vector<float> &vec =
|
std::vector<float> &vec =
|
||||||
static_cast<DataMatrix*>(handle)->info.GetInfo(field);
|
static_cast<DataMatrix*>(handle)->info.GetFloatInfo(field);
|
||||||
vec.resize(len);
|
vec.resize(len);
|
||||||
memcpy(&vec[0], info, sizeof(float) * len);
|
memcpy(&vec[0], info, sizeof(float) * len);
|
||||||
}
|
}
|
||||||
|
void XGDMatrixSetUIntInfo(void *handle, const char *field, const unsigned *info, size_t len) {
|
||||||
|
std::vector<unsigned> &vec =
|
||||||
|
static_cast<DataMatrix*>(handle)->info.GetUIntInfo(field);
|
||||||
|
vec.resize(len);
|
||||||
|
memcpy(&vec[0], info, sizeof(unsigned) * len);
|
||||||
|
}
|
||||||
void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len) {
|
void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len) {
|
||||||
DataMatrix *pmat = static_cast<DataMatrix*>(handle);
|
DataMatrix *pmat = static_cast<DataMatrix*>(handle);
|
||||||
pmat->info.group_ptr.resize(len + 1);
|
pmat->info.group_ptr.resize(len + 1);
|
||||||
@ -179,12 +185,18 @@ extern "C"{
|
|||||||
}
|
}
|
||||||
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* len) {
|
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* len) {
|
||||||
const std::vector<float> &vec =
|
const std::vector<float> &vec =
|
||||||
static_cast<const DataMatrix*>(handle)->info.GetInfo(field);
|
static_cast<const DataMatrix*>(handle)->info.GetFloatInfo(field);
|
||||||
|
*len = vec.size();
|
||||||
|
return &vec[0];
|
||||||
|
}
|
||||||
|
const unsigned* XGDMatrixGetUIntInfo(const void *handle, const char *field, size_t* len) {
|
||||||
|
const std::vector<unsigned> &vec =
|
||||||
|
static_cast<const DataMatrix*>(handle)->info.GetUIntInfo(field);
|
||||||
*len = vec.size();
|
*len = vec.size();
|
||||||
return &vec[0];
|
return &vec[0];
|
||||||
}
|
}
|
||||||
size_t XGDMatrixNumRow(const void *handle) {
|
size_t XGDMatrixNumRow(const void *handle) {
|
||||||
return static_cast<const DataMatrix*>(handle)->info.num_row;
|
return static_cast<const DataMatrix*>(handle)->info.num_row();
|
||||||
}
|
}
|
||||||
|
|
||||||
// xgboost implementation
|
// xgboost implementation
|
||||||
|
|||||||
@ -69,6 +69,14 @@ extern "C" {
|
|||||||
* \param len length of array
|
* \param len length of array
|
||||||
*/
|
*/
|
||||||
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *array, size_t len);
|
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *array, size_t len);
|
||||||
|
/*!
|
||||||
|
* \brief set uint32 vector to a content in info
|
||||||
|
* \param handle a instance of data matrix
|
||||||
|
* \param field field name
|
||||||
|
* \param array pointer to float vector
|
||||||
|
* \param len length of array
|
||||||
|
*/
|
||||||
|
void XGDMatrixSetUIntInfo(void *handle, const char *field, const unsigned *array, size_t len);
|
||||||
/*!
|
/*!
|
||||||
* \brief set label of the training matrix
|
* \brief set label of the training matrix
|
||||||
* \param handle a instance of data matrix
|
* \param handle a instance of data matrix
|
||||||
@ -81,9 +89,17 @@ extern "C" {
|
|||||||
* \param handle a instance of data matrix
|
* \param handle a instance of data matrix
|
||||||
* \param field field name
|
* \param field field name
|
||||||
* \param out_len used to set result length
|
* \param out_len used to set result length
|
||||||
* \return pointer to the label
|
* \return pointer to the result
|
||||||
*/
|
*/
|
||||||
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* out_len);
|
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* out_len);
|
||||||
|
/*!
|
||||||
|
* \brief get uint32 info vector from matrix
|
||||||
|
* \param handle a instance of data matrix
|
||||||
|
* \param field field name
|
||||||
|
* \param out_len used to set result length
|
||||||
|
* \return pointer to the result
|
||||||
|
*/
|
||||||
|
const unsigned* XGDMatrixGetUIntInfo(const void *handle, const char *field, size_t* out_len);
|
||||||
/*!
|
/*!
|
||||||
* \brief return number of rows
|
* \brief return number of rows
|
||||||
*/
|
*/
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user