finish dump

This commit is contained in:
tqchen 2014-08-23 13:09:47 -07:00
parent 40da2fa2c0
commit 3ba7995754
4 changed files with 112 additions and 7 deletions

View File

@ -65,6 +65,18 @@ xgb.eval <- function(booster, watchlist, iter) {
msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, evnames) msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, evnames)
return(msg) return(msg)
} }
xgb.save <- function(handle, fname) {
if (typeof(fname) == "character") {
stop("xgb.save: fname must be character");
}
if (class(handle) != "xgb.Booster") {
.Call("XGBoosterSaveModel_R", handle, fname);
return(TRUE)
}
if (class(handle) != "xgb.DMatrix") {
}
}
# test code here # test code here

View File

@ -165,7 +165,7 @@ extern "C" {
* \param out_len length of output array * \param out_len length of output array
* \return char *data[], representing dump of each model * \return char *data[], representing dump of each model
*/ */
const char** XGBoosterDumpModel(void *handle, const char *fmap, const char **XGBoosterDumpModel(void *handle, const char *fmap,
size_t *out_len); size_t *out_len);
}; };
#endif // XGBOOST_WRAPPER_H_ #endif // XGBOOST_WRAPPER_H_

View File

@ -3,6 +3,8 @@
#include "xgboost_wrapper.h" #include "xgboost_wrapper.h"
#include "xgboost_wrapper_R.h" #include "xgboost_wrapper_R.h"
#include "../src/utils/utils.h" #include "../src/utils/utils.h"
#include "../src/utils/omp.h"
using namespace xgboost; using namespace xgboost;
extern "C" { extern "C" {
@ -11,13 +13,17 @@ extern "C" {
XGDMatrixFree(R_ExternalPtrAddr(ext)); XGDMatrixFree(R_ExternalPtrAddr(ext));
R_ClearExternalPtr(ext); R_ClearExternalPtr(ext);
} }
SEXP XGDMatrixCreateFromFile_R(SEXP fname) { SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
void *handle = XGDMatrixCreateFromFile(CHAR(asChar(fname)), 0); void *handle = XGDMatrixCreateFromFile(CHAR(asChar(fname)), asInteger(silent));
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
return ret; return ret;
} }
void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
XGDMatrixSaveBinary(R_ExternalPtrAddr(handle),
CHAR(asChar(fname)), asInteger(silent));
}
// functions related to booster // functions related to booster
void _BoosterFinalizer(SEXP ext) { void _BoosterFinalizer(SEXP ext) {
@ -47,6 +53,19 @@ extern "C" {
asInteger(iter), asInteger(iter),
R_ExternalPtrAddr(dtrain)); R_ExternalPtrAddr(dtrain));
} }
void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
utils::Check(length(grad) == length(hess), "gradient and hess must have same length");
int len = length(grad);
std::vector<float> tgrad(len), thess(len);
#pragma omp parallel for schedule(static)
for (int j = 0; j < len; ++j) {
tgrad[j] = REAL(grad)[j];
thess[j] = REAL(hess)[j];
}
XGBoosterBoostOneIter(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dtrain),
&tgrad[0], &thess[0], len);
}
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) { SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
utils::Check(length(dmats) == length(evnames), "dmats and evnams must have same length"); utils::Check(length(dmats) == length(evnames), "dmats and evnams must have same length");
int len = length(dmats); int len = length(dmats);
@ -62,4 +81,35 @@ extern "C" {
asInteger(iter), asInteger(iter),
&vec_dmats[0], &vec_sptr[0], len)); &vec_dmats[0], &vec_sptr[0], len));
} }
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin) {
size_t olen;
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat),
asInteger(output_margin),
&olen);
SEXP ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {
REAL(ret)[i] = res[i];
}
UNPROTECT(1);
return ret;
}
void XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
}
void XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
}
void XGBoosterDumpModel_R(SEXP handle, SEXP fname, SEXP fmap) {
size_t olen;
const char **res = XGBoosterDumpModel(R_ExternalPtrAddr(handle),
CHAR(asChar(fmap)),
&olen);
FILE *fo = utils::FopenCheck(CHAR(asChar(fname)), "w");
for (size_t i = 0; i < olen; ++i) {
fprintf(fo, "booster[%lu]:\n", i);
fprintf(fo, "%s\n", res[i]);
}
fclose(fo);
}
} }

View File

@ -12,10 +12,18 @@ extern "C" {
extern "C" { extern "C" {
/*! /*!
* \brief load a data matrix * \brief load a data matrix
* \fname name of the content * \param fname name of the content
* \param silent whether print messages
* \return a loaded data matrix * \return a loaded data matrix
*/ */
SEXP XGDMatrixCreateFromFile_R(SEXP fname); SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent);
/*!
* \brief load a data matrix into binary file
* \param handle a instance of data matrix
* \param fname file name
* \param silent print statistics when saving
*/
void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent);
/*! /*!
* \brief create xgboost learner * \brief create xgboost learner
* \param dmats a list of dmatrix handles that will be cached * \param dmats a list of dmatrix handles that will be cached
@ -35,6 +43,15 @@ extern "C" {
* \param dtrain training data * \param dtrain training data
*/ */
void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain); void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
/*!
* \brief update the model, by directly specify gradient and second order gradient,
* this can be used to replace UpdateOneIter, to support customized loss function
* \param handle handle
* \param dtrain training data
* \param grad gradient statistics
* \param hess second order gradient statistics
*/
void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess);
/*! /*!
* \brief get evaluation statistics for xgboost * \brief get evaluation statistics for xgboost
* \param handle handle * \param handle handle
@ -44,5 +61,31 @@ extern "C" {
* \return the string containing evaluation stati * \return the string containing evaluation stati
*/ */
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames); SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames);
/*!
* \brief make prediction based on dmat
* \param handle handle
* \param dmat data matrix
* \param output_margin whether only output raw margin value
*/
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin);
/*!
* \brief load model from existing file
* \param handle handle
* \param fname file name
*/
void XGBoosterLoadModel_R(SEXP handle, SEXP fname);
/*!
* \brief save model into existing file
* \param handle handle
* \param fname file name
*/
void XGBoosterSaveModel_R(SEXP handle, SEXP fname);
/*!
* \brief dump model into text file
* \param handle handle
* \param fname file name of model that can be dumped into
* \param fmap name to fmap can be empty string
*/
void XGBoosterDumpModel_R(SEXP handle, SEXP fname, SEXP fmap);
}; };
#endif // XGBOOST_WRAPPER_H_ #endif // XGBOOST_WRAPPER_R_H_