Implement feature score for linear model. (#7048)
* Add feature score support for linear model. * Port R interface to the new implementation. * Add linear model support in Python. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -96,41 +96,44 @@ xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL,
|
||||
if (!(is.null(feature_names) || is.character(feature_names)))
|
||||
stop("feature_names: Has to be a character vector")
|
||||
|
||||
model_text_dump <- xgb.dump(model = model, with_stats = TRUE)
|
||||
|
||||
# linear model
|
||||
if (model_text_dump[2] == "bias:"){
|
||||
weight_index <- which(model_text_dump == "weight:") + 1
|
||||
weights <- as.numeric(
|
||||
model_text_dump[weight_index:length(model_text_dump)]
|
||||
model <- xgb.Booster.complete(model)
|
||||
config <- jsonlite::fromJSON(xgb.config(model))
|
||||
if (config$learner$gradient_booster$name == "gblinear") {
|
||||
args <- list(importance_type = "weight", feature_names = feature_names)
|
||||
results <- .Call(
|
||||
XGBoosterFeatureScore_R, model$handle, jsonlite::toJSON(args, auto_unbox = TRUE, null = "null")
|
||||
)
|
||||
|
||||
num_class <- NVL(model$params$num_class, 1)
|
||||
if (is.null(feature_names))
|
||||
feature_names <- seq(to = length(weights) / num_class) - 1
|
||||
if (length(feature_names) * num_class != length(weights))
|
||||
stop("feature_names length does not match the number of features used in the model")
|
||||
|
||||
result <- if (num_class == 1) {
|
||||
data.table(Feature = feature_names, Weight = weights)[order(-abs(Weight))]
|
||||
names(results) <- c("features", "shape", "weight")
|
||||
n_classes <- if (length(results$shape) == 2) { results$shape[2] } else { 0 }
|
||||
importance <- if (n_classes == 0) {
|
||||
data.table(Feature = results$features, Weight = results$weight)[order(-abs(Weight))]
|
||||
} else {
|
||||
data.table(Feature = rep(feature_names, each = num_class),
|
||||
Weight = weights,
|
||||
Class = seq_len(num_class) - 1)[order(Class, -abs(Weight))]
|
||||
data.table(
|
||||
Feature = rep(results$features, each = n_classes), Weight = results$weight, Class = seq_len(n_classes) - 1
|
||||
)[order(Class, -abs(Weight))]
|
||||
}
|
||||
} else { # tree model
|
||||
result <- xgb.model.dt.tree(feature_names = feature_names,
|
||||
text = model_text_dump,
|
||||
trees = trees)[
|
||||
Feature != "Leaf", .(Gain = sum(Quality),
|
||||
Cover = sum(Cover),
|
||||
Frequency = .N), by = Feature][
|
||||
, `:=`(Gain = Gain / sum(Gain),
|
||||
Cover = Cover / sum(Cover),
|
||||
Frequency = Frequency / sum(Frequency))][
|
||||
order(Gain, decreasing = TRUE)]
|
||||
} else {
|
||||
concatenated <- list()
|
||||
output_names <- vector()
|
||||
for (importance_type in c("weight", "gain", "cover")) {
|
||||
args <- list(importance_type = importance_type, feature_names = feature_names)
|
||||
results <- .Call(
|
||||
XGBoosterFeatureScore_R, model$handle, jsonlite::toJSON(args, auto_unbox = TRUE, null = "null")
|
||||
)
|
||||
names(results) <- c("features", "shape", importance_type)
|
||||
concatenated[
|
||||
switch(importance_type, "weight" = "Frequency", "gain" = "Gain", "cover" = "Cover")
|
||||
] <- results[importance_type]
|
||||
output_names <- results$features
|
||||
}
|
||||
importance <- data.table(
|
||||
Feature = output_names,
|
||||
Gain = concatenated$Gain / sum(concatenated$Gain),
|
||||
Cover = concatenated$Cover / sum(concatenated$Cover),
|
||||
Frequency = concatenated$Frequency / sum(concatenated$Frequency)
|
||||
)[order(Gain, decreasing = TRUE)]
|
||||
}
|
||||
result
|
||||
importance
|
||||
}
|
||||
|
||||
# Avoid error messages during CRAN check.
|
||||
|
||||
@@ -47,6 +47,7 @@ extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
|
||||
extern SEXP XGBSetGlobalConfig_R(SEXP);
|
||||
extern SEXP XGBGetGlobalConfig_R();
|
||||
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
|
||||
|
||||
static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterBoostOneIter_R, 4},
|
||||
@@ -81,6 +82,7 @@ static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
|
||||
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
|
||||
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
|
||||
{"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2},
|
||||
{NULL, NULL, 0}
|
||||
};
|
||||
|
||||
|
||||
@@ -38,11 +38,11 @@
|
||||
|
||||
using namespace dmlc;
|
||||
|
||||
SEXP XGCheckNullPtr_R(SEXP handle) {
|
||||
XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle) {
|
||||
return ScalarLogical(R_ExternalPtrAddr(handle) == NULL);
|
||||
}
|
||||
|
||||
void _DMatrixFinalizer(SEXP ext) {
|
||||
XGB_DLL void _DMatrixFinalizer(SEXP ext) {
|
||||
R_API_BEGIN();
|
||||
if (R_ExternalPtrAddr(ext) == NULL) return;
|
||||
CHECK_CALL(XGDMatrixFree(R_ExternalPtrAddr(ext)));
|
||||
@@ -50,14 +50,14 @@ void _DMatrixFinalizer(SEXP ext) {
|
||||
R_API_END();
|
||||
}
|
||||
|
||||
SEXP XGBSetGlobalConfig_R(SEXP json_str) {
|
||||
XGB_DLL SEXP XGBSetGlobalConfig_R(SEXP json_str) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBSetGlobalConfig(CHAR(asChar(json_str))));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBGetGlobalConfig_R() {
|
||||
XGB_DLL SEXP XGBGetGlobalConfig_R() {
|
||||
const char* json_str;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBGetGlobalConfig(&json_str));
|
||||
@@ -65,7 +65,7 @@ SEXP XGBGetGlobalConfig_R() {
|
||||
return mkString(json_str);
|
||||
}
|
||||
|
||||
SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
|
||||
XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
DMatrixHandle handle;
|
||||
@@ -77,8 +77,7 @@ SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
SEXP XGDMatrixCreateFromMat_R(SEXP mat,
|
||||
SEXP missing) {
|
||||
XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
SEXP dim = getAttrib(mat, R_DimSymbol);
|
||||
@@ -112,10 +111,8 @@ SEXP XGDMatrixCreateFromMat_R(SEXP mat,
|
||||
return ret;
|
||||
}
|
||||
|
||||
SEXP XGDMatrixCreateFromCSC_R(SEXP indptr,
|
||||
SEXP indices,
|
||||
SEXP data,
|
||||
SEXP num_row) {
|
||||
XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data,
|
||||
SEXP num_row) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
const int *p_indptr = INTEGER(indptr);
|
||||
@@ -151,7 +148,7 @@ SEXP XGDMatrixCreateFromCSC_R(SEXP indptr,
|
||||
return ret;
|
||||
}
|
||||
|
||||
SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
|
||||
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
int len = length(idxset);
|
||||
@@ -171,7 +168,7 @@ SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
|
||||
XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGDMatrixSaveBinary(R_ExternalPtrAddr(handle),
|
||||
CHAR(asChar(fname)),
|
||||
@@ -180,7 +177,7 @@ SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||
R_API_BEGIN();
|
||||
int len = length(array);
|
||||
const char *name = CHAR(asChar(field));
|
||||
@@ -214,7 +211,7 @@ SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
||||
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
@@ -232,7 +229,7 @@ SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
SEXP XGDMatrixNumRow_R(SEXP handle) {
|
||||
XGB_DLL SEXP XGDMatrixNumRow_R(SEXP handle) {
|
||||
bst_ulong nrow;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGDMatrixNumRow(R_ExternalPtrAddr(handle), &nrow));
|
||||
@@ -240,7 +237,7 @@ SEXP XGDMatrixNumRow_R(SEXP handle) {
|
||||
return ScalarInteger(static_cast<int>(nrow));
|
||||
}
|
||||
|
||||
SEXP XGDMatrixNumCol_R(SEXP handle) {
|
||||
XGB_DLL SEXP XGDMatrixNumCol_R(SEXP handle) {
|
||||
bst_ulong ncol;
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGDMatrixNumCol(R_ExternalPtrAddr(handle), &ncol));
|
||||
@@ -255,7 +252,7 @@ void _BoosterFinalizer(SEXP ext) {
|
||||
R_ClearExternalPtr(ext);
|
||||
}
|
||||
|
||||
SEXP XGBoosterCreate_R(SEXP dmats) {
|
||||
XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
int len = length(dmats);
|
||||
@@ -272,7 +269,7 @@ SEXP XGBoosterCreate_R(SEXP dmats) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) {
|
||||
XGB_DLL SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) {
|
||||
R_API_BEGIN();
|
||||
int len = length(dmats);
|
||||
std::vector<void*> dvec;
|
||||
@@ -287,7 +284,7 @@ SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
|
||||
XGB_DLL SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBoosterSetParam(R_ExternalPtrAddr(handle),
|
||||
CHAR(asChar(name)),
|
||||
@@ -296,7 +293,7 @@ SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
|
||||
XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
|
||||
asInteger(iter),
|
||||
@@ -305,7 +302,7 @@ SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
|
||||
XGB_DLL SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
|
||||
R_API_BEGIN();
|
||||
CHECK_EQ(length(grad), length(hess))
|
||||
<< "gradient and hess must have same length";
|
||||
@@ -328,7 +325,7 @@ SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
|
||||
XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
|
||||
const char *ret;
|
||||
R_API_BEGIN();
|
||||
CHECK_EQ(length(dmats), length(evnames))
|
||||
@@ -353,8 +350,8 @@ SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
|
||||
return mkString(ret);
|
||||
}
|
||||
|
||||
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask,
|
||||
SEXP ntree_limit, SEXP training) {
|
||||
XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask,
|
||||
SEXP ntree_limit, SEXP training) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
@@ -374,7 +371,7 @@ SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask,
|
||||
return ret;
|
||||
}
|
||||
|
||||
SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
|
||||
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
|
||||
SEXP r_out_shape;
|
||||
SEXP r_out_result;
|
||||
SEXP r_out;
|
||||
@@ -413,21 +410,21 @@ SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
|
||||
return r_out;
|
||||
}
|
||||
|
||||
SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
|
||||
XGB_DLL SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
|
||||
XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterModelToRaw_R(SEXP handle) {
|
||||
XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
@@ -442,7 +439,7 @@ SEXP XGBoosterModelToRaw_R(SEXP handle) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
|
||||
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
|
||||
RAW(raw),
|
||||
@@ -451,7 +448,7 @@ SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
|
||||
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
|
||||
const char* ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong len {0};
|
||||
@@ -462,14 +459,14 @@ SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
|
||||
return mkString(ret);
|
||||
}
|
||||
|
||||
SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) {
|
||||
XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBoosterLoadJsonConfig(R_ExternalPtrAddr(handle), CHAR(asChar(value))));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterSerializeToBuffer_R(SEXP handle) {
|
||||
XGB_DLL SEXP XGBoosterSerializeToBuffer_R(SEXP handle) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong out_len;
|
||||
@@ -484,7 +481,7 @@ SEXP XGBoosterSerializeToBuffer_R(SEXP handle) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) {
|
||||
XGB_DLL SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBoosterUnserializeFromBuffer(R_ExternalPtrAddr(handle),
|
||||
RAW(raw),
|
||||
@@ -493,7 +490,7 @@ SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) {
|
||||
XGB_DLL SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) {
|
||||
SEXP out;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
@@ -530,7 +527,7 @@ SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_for
|
||||
return out;
|
||||
}
|
||||
|
||||
SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name) {
|
||||
XGB_DLL SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name) {
|
||||
SEXP out;
|
||||
R_API_BEGIN();
|
||||
int success;
|
||||
@@ -550,7 +547,7 @@ SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name) {
|
||||
return out;
|
||||
}
|
||||
|
||||
SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) {
|
||||
XGB_DLL SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) {
|
||||
R_API_BEGIN();
|
||||
const char *v = isNull(val) ? nullptr : CHAR(asChar(val));
|
||||
CHECK_CALL(XGBoosterSetAttr(R_ExternalPtrAddr(handle),
|
||||
@@ -559,7 +556,7 @@ SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
SEXP XGBoosterGetAttrNames_R(SEXP handle) {
|
||||
XGB_DLL SEXP XGBoosterGetAttrNames_R(SEXP handle) {
|
||||
SEXP out;
|
||||
R_API_BEGIN();
|
||||
bst_ulong len;
|
||||
@@ -578,3 +575,51 @@ SEXP XGBoosterGetAttrNames_R(SEXP handle) {
|
||||
UNPROTECT(1);
|
||||
return out;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config) {
|
||||
SEXP out_features_sexp;
|
||||
SEXP out_scores_sexp;
|
||||
SEXP out_shape_sexp;
|
||||
SEXP r_out;
|
||||
|
||||
R_API_BEGIN();
|
||||
char const *c_json_config = CHAR(asChar(json_config));
|
||||
bst_ulong out_n_features;
|
||||
char const **out_features;
|
||||
|
||||
bst_ulong out_dim;
|
||||
bst_ulong const *out_shape;
|
||||
float const *out_scores;
|
||||
|
||||
CHECK_CALL(XGBoosterFeatureScore(R_ExternalPtrAddr(handle), c_json_config,
|
||||
&out_n_features, &out_features,
|
||||
&out_dim, &out_shape, &out_scores));
|
||||
|
||||
out_shape_sexp = PROTECT(allocVector(INTSXP, out_dim));
|
||||
size_t len = 1;
|
||||
for (size_t i = 0; i < out_dim; ++i) {
|
||||
INTEGER(out_shape_sexp)[i] = out_shape[i];
|
||||
len *= out_shape[i];
|
||||
}
|
||||
|
||||
out_scores_sexp = PROTECT(allocVector(REALSXP, len));
|
||||
#pragma omp parallel for
|
||||
for (omp_ulong i = 0; i < len; ++i) {
|
||||
REAL(out_scores_sexp)[i] = out_scores[i];
|
||||
}
|
||||
|
||||
out_features_sexp = PROTECT(allocVector(STRSXP, out_n_features));
|
||||
for (size_t i = 0; i < out_n_features; ++i) {
|
||||
SET_STRING_ELT(out_features_sexp, i, mkChar(out_features[i]));
|
||||
}
|
||||
|
||||
r_out = PROTECT(allocVector(VECSXP, 3));
|
||||
SET_VECTOR_ELT(r_out, 0, out_features_sexp);
|
||||
SET_VECTOR_ELT(r_out, 1, out_shape_sexp);
|
||||
SET_VECTOR_ELT(r_out, 2, out_scores_sexp);
|
||||
|
||||
R_API_END();
|
||||
UNPROTECT(4);
|
||||
|
||||
return r_out;
|
||||
}
|
||||
|
||||
@@ -275,4 +275,12 @@ XGB_DLL SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val);
|
||||
*/
|
||||
XGB_DLL SEXP XGBoosterGetAttrNames_R(SEXP handle);
|
||||
|
||||
/*!
|
||||
* \brief Get feature scores from the model.
|
||||
* \param json_config See `XGBoosterFeatureScore` in xgboost c_api.h
|
||||
* \return A vector with the first element as feature names, second element as shape of
|
||||
* feature scores and thrid element as feature scores.
|
||||
*/
|
||||
XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config);
|
||||
|
||||
#endif // XGBOOST_WRAPPER_R_H_ // NOLINT(*)
|
||||
|
||||
Reference in New Issue
Block a user