From 663136aa08c00598d8b49adf5901e4cb2ce187da Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 25 Jun 2021 14:34:02 +0800 Subject: [PATCH] Implement feature score for linear model. (#7048) * Add feature score support for linear model. * Port R interface to the new implementation. * Add linear model support in Python. Co-authored-by: Philip Hyunsu Cho --- R-package/R/xgb.importance.R | 65 +++++----- R-package/src/init.c | 2 + R-package/src/xgboost_R.cc | 121 ++++++++++++------ R-package/src/xgboost_R.h | 8 ++ demo/guide-python/basic_walkthrough.py | 4 +- include/xgboost/c_api.h | 26 ++-- include/xgboost/gbm.h | 4 +- include/xgboost/linalg.h | 8 ++ .../dmlc/xgboost4j/java/BoosterImplTest.java | 81 ++++++------ python-package/xgboost/core.py | 83 +++++------- python-package/xgboost/sklearn.py | 34 ++--- src/c_api/c_api.cc | 62 +++++---- src/c_api/c_api_utils.h | 20 ++- src/gbm/gblinear.cc | 22 ++++ src/gbm/gbtree.h | 11 +- src/learner.cc | 17 --- tests/python/test_basic.py | 3 + tests/python/test_with_sklearn.py | 28 ++++ 18 files changed, 367 insertions(+), 232 deletions(-) diff --git a/R-package/R/xgb.importance.R b/R-package/R/xgb.importance.R index 7305ee571..5176a9d54 100644 --- a/R-package/R/xgb.importance.R +++ b/R-package/R/xgb.importance.R @@ -96,41 +96,44 @@ xgb.importance <- function(feature_names = NULL, model = NULL, trees = NULL, if (!(is.null(feature_names) || is.character(feature_names))) stop("feature_names: Has to be a character vector") - model_text_dump <- xgb.dump(model = model, with_stats = TRUE) - - # linear model - if (model_text_dump[2] == "bias:"){ - weight_index <- which(model_text_dump == "weight:") + 1 - weights <- as.numeric( - model_text_dump[weight_index:length(model_text_dump)] + model <- xgb.Booster.complete(model) + config <- jsonlite::fromJSON(xgb.config(model)) + if (config$learner$gradient_booster$name == "gblinear") { + args <- list(importance_type = "weight", feature_names = feature_names) + results <- .Call( + XGBoosterFeatureScore_R, model$handle, jsonlite::toJSON(args, auto_unbox = TRUE, null = "null") ) - - num_class <- NVL(model$params$num_class, 1) - if (is.null(feature_names)) - feature_names <- seq(to = length(weights) / num_class) - 1 - if (length(feature_names) * num_class != length(weights)) - stop("feature_names length does not match the number of features used in the model") - - result <- if (num_class == 1) { - data.table(Feature = feature_names, Weight = weights)[order(-abs(Weight))] + names(results) <- c("features", "shape", "weight") + n_classes <- if (length(results$shape) == 2) { results$shape[2] } else { 0 } + importance <- if (n_classes == 0) { + data.table(Feature = results$features, Weight = results$weight)[order(-abs(Weight))] } else { - data.table(Feature = rep(feature_names, each = num_class), - Weight = weights, - Class = seq_len(num_class) - 1)[order(Class, -abs(Weight))] + data.table( + Feature = rep(results$features, each = n_classes), Weight = results$weight, Class = seq_len(n_classes) - 1 + )[order(Class, -abs(Weight))] } - } else { # tree model - result <- xgb.model.dt.tree(feature_names = feature_names, - text = model_text_dump, - trees = trees)[ - Feature != "Leaf", .(Gain = sum(Quality), - Cover = sum(Cover), - Frequency = .N), by = Feature][ - , `:=`(Gain = Gain / sum(Gain), - Cover = Cover / sum(Cover), - Frequency = Frequency / sum(Frequency))][ - order(Gain, decreasing = TRUE)] + } else { + concatenated <- list() + output_names <- vector() + for (importance_type in c("weight", "gain", "cover")) { + args <- 
list(importance_type = importance_type, feature_names = feature_names) + results <- .Call( + XGBoosterFeatureScore_R, model$handle, jsonlite::toJSON(args, auto_unbox = TRUE, null = "null") + ) + names(results) <- c("features", "shape", importance_type) + concatenated[ + switch(importance_type, "weight" = "Frequency", "gain" = "Gain", "cover" = "Cover") + ] <- results[importance_type] + output_names <- results$features + } + importance <- data.table( + Feature = output_names, + Gain = concatenated$Gain / sum(concatenated$Gain), + Cover = concatenated$Cover / sum(concatenated$Cover), + Frequency = concatenated$Frequency / sum(concatenated$Frequency) + )[order(Gain, decreasing = TRUE)] } - result + importance } # Avoid error messages during CRAN check. diff --git a/R-package/src/init.c b/R-package/src/init.c index 789a0b625..5f136ff22 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -47,6 +47,7 @@ extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP); extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP); extern SEXP XGBSetGlobalConfig_R(SEXP); extern SEXP XGBGetGlobalConfig_R(); +extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP); static const R_CallMethodDef CallEntries[] = { {"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterBoostOneIter_R, 4}, @@ -81,6 +82,7 @@ static const R_CallMethodDef CallEntries[] = { {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2}, {"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1}, {"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0}, + {"XGBoosterFeatureScore_R", (DL_FUNC) &XGBoosterFeatureScore_R, 2}, {NULL, NULL, 0} }; diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 0746da30f..56fb61f8d 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -38,11 +38,11 @@ using namespace dmlc; -SEXP XGCheckNullPtr_R(SEXP handle) { +XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle) { return ScalarLogical(R_ExternalPtrAddr(handle) == NULL); } -void _DMatrixFinalizer(SEXP ext) { +XGB_DLL void _DMatrixFinalizer(SEXP ext) { R_API_BEGIN(); if (R_ExternalPtrAddr(ext) == NULL) return; CHECK_CALL(XGDMatrixFree(R_ExternalPtrAddr(ext))); @@ -50,14 +50,14 @@ void _DMatrixFinalizer(SEXP ext) { R_API_END(); } -SEXP XGBSetGlobalConfig_R(SEXP json_str) { +XGB_DLL SEXP XGBSetGlobalConfig_R(SEXP json_str) { R_API_BEGIN(); CHECK_CALL(XGBSetGlobalConfig(CHAR(asChar(json_str)))); R_API_END(); return R_NilValue; } -SEXP XGBGetGlobalConfig_R() { +XGB_DLL SEXP XGBGetGlobalConfig_R() { const char* json_str; R_API_BEGIN(); CHECK_CALL(XGBGetGlobalConfig(&json_str)); @@ -65,7 +65,7 @@ SEXP XGBGetGlobalConfig_R() { return mkString(json_str); } -SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) { +XGB_DLL SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) { SEXP ret; R_API_BEGIN(); DMatrixHandle handle; @@ -77,8 +77,7 @@ SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) { return ret; } -SEXP XGDMatrixCreateFromMat_R(SEXP mat, - SEXP missing) { +XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing) { SEXP ret; R_API_BEGIN(); SEXP dim = getAttrib(mat, R_DimSymbol); @@ -112,10 +111,8 @@ SEXP XGDMatrixCreateFromMat_R(SEXP mat, return ret; } -SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, - SEXP indices, - SEXP data, - SEXP num_row) { +XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, + SEXP num_row) { SEXP ret; R_API_BEGIN(); const int *p_indptr = INTEGER(indptr); @@ -151,7 +148,7 @@ SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, return ret; } -SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP 
idxset) { +XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { SEXP ret; R_API_BEGIN(); int len = length(idxset); @@ -171,7 +168,7 @@ SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { return ret; } -SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) { +XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) { R_API_BEGIN(); CHECK_CALL(XGDMatrixSaveBinary(R_ExternalPtrAddr(handle), CHAR(asChar(fname)), @@ -180,7 +177,7 @@ SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) { return R_NilValue; } -SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) { +XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) { R_API_BEGIN(); int len = length(array); const char *name = CHAR(asChar(field)); @@ -214,7 +211,7 @@ SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) { return R_NilValue; } -SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) { +XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) { SEXP ret; R_API_BEGIN(); bst_ulong olen; @@ -232,7 +229,7 @@ SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) { return ret; } -SEXP XGDMatrixNumRow_R(SEXP handle) { +XGB_DLL SEXP XGDMatrixNumRow_R(SEXP handle) { bst_ulong nrow; R_API_BEGIN(); CHECK_CALL(XGDMatrixNumRow(R_ExternalPtrAddr(handle), &nrow)); @@ -240,7 +237,7 @@ SEXP XGDMatrixNumRow_R(SEXP handle) { return ScalarInteger(static_cast(nrow)); } -SEXP XGDMatrixNumCol_R(SEXP handle) { +XGB_DLL SEXP XGDMatrixNumCol_R(SEXP handle) { bst_ulong ncol; R_API_BEGIN(); CHECK_CALL(XGDMatrixNumCol(R_ExternalPtrAddr(handle), &ncol)); @@ -255,7 +252,7 @@ void _BoosterFinalizer(SEXP ext) { R_ClearExternalPtr(ext); } -SEXP XGBoosterCreate_R(SEXP dmats) { +XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats) { SEXP ret; R_API_BEGIN(); int len = length(dmats); @@ -272,7 +269,7 @@ SEXP XGBoosterCreate_R(SEXP dmats) { return ret; } -SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) { +XGB_DLL SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) { R_API_BEGIN(); int len = length(dmats); std::vector dvec; @@ -287,7 +284,7 @@ SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) { return R_NilValue; } -SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) { +XGB_DLL SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) { R_API_BEGIN(); CHECK_CALL(XGBoosterSetParam(R_ExternalPtrAddr(handle), CHAR(asChar(name)), @@ -296,7 +293,7 @@ SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) { return R_NilValue; } -SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) { +XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) { R_API_BEGIN(); CHECK_CALL(XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle), asInteger(iter), @@ -305,7 +302,7 @@ SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) { return R_NilValue; } -SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) { +XGB_DLL SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) { R_API_BEGIN(); CHECK_EQ(length(grad), length(hess)) << "gradient and hess must have same length"; @@ -328,7 +325,7 @@ SEXP XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) { return R_NilValue; } -SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) { +XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) { const char *ret; R_API_BEGIN(); CHECK_EQ(length(dmats), length(evnames)) @@ -353,8 +350,8 @@ SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, 
SEXP evnames) { return mkString(ret); } -SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, - SEXP ntree_limit, SEXP training) { +XGB_DLL SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, + SEXP ntree_limit, SEXP training) { SEXP ret; R_API_BEGIN(); bst_ulong olen; @@ -374,7 +371,7 @@ SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, return ret; } -SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) { +XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) { SEXP r_out_shape; SEXP r_out_result; SEXP r_out; @@ -413,21 +410,21 @@ SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) { return r_out; } -SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) { +XGB_DLL SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) { R_API_BEGIN(); CHECK_CALL(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)))); R_API_END(); return R_NilValue; } -SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) { +XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) { R_API_BEGIN(); CHECK_CALL(XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)))); R_API_END(); return R_NilValue; } -SEXP XGBoosterModelToRaw_R(SEXP handle) { +XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle) { SEXP ret; R_API_BEGIN(); bst_ulong olen; @@ -442,7 +439,7 @@ SEXP XGBoosterModelToRaw_R(SEXP handle) { return ret; } -SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) { +XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) { R_API_BEGIN(); CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle), RAW(raw), @@ -451,7 +448,7 @@ SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) { return R_NilValue; } -SEXP XGBoosterSaveJsonConfig_R(SEXP handle) { +XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle) { const char* ret; R_API_BEGIN(); bst_ulong len {0}; @@ -462,14 +459,14 @@ SEXP XGBoosterSaveJsonConfig_R(SEXP handle) { return mkString(ret); } -SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) { +XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) { R_API_BEGIN(); CHECK_CALL(XGBoosterLoadJsonConfig(R_ExternalPtrAddr(handle), CHAR(asChar(value)))); R_API_END(); return R_NilValue; } -SEXP XGBoosterSerializeToBuffer_R(SEXP handle) { +XGB_DLL SEXP XGBoosterSerializeToBuffer_R(SEXP handle) { SEXP ret; R_API_BEGIN(); bst_ulong out_len; @@ -484,7 +481,7 @@ SEXP XGBoosterSerializeToBuffer_R(SEXP handle) { return ret; } -SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) { +XGB_DLL SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) { R_API_BEGIN(); CHECK_CALL(XGBoosterUnserializeFromBuffer(R_ExternalPtrAddr(handle), RAW(raw), @@ -493,7 +490,7 @@ SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) { return R_NilValue; } -SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) { +XGB_DLL SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) { SEXP out; R_API_BEGIN(); bst_ulong olen; @@ -530,7 +527,7 @@ SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_for return out; } -SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name) { +XGB_DLL SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name) { SEXP out; R_API_BEGIN(); int success; @@ -550,7 +547,7 @@ SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name) { return out; } -SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) { +XGB_DLL SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) { R_API_BEGIN(); const char *v 
= isNull(val) ? nullptr : CHAR(asChar(val));
   CHECK_CALL(XGBoosterSetAttr(R_ExternalPtrAddr(handle),
@@ -559,7 +556,7 @@ SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) {
   return R_NilValue;
 }
 
-SEXP XGBoosterGetAttrNames_R(SEXP handle) {
+XGB_DLL SEXP XGBoosterGetAttrNames_R(SEXP handle) {
   SEXP out;
   R_API_BEGIN();
   bst_ulong len;
@@ -578,3 +575,51 @@ SEXP XGBoosterGetAttrNames_R(SEXP handle) {
   UNPROTECT(1);
   return out;
 }
+
+XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config) {
+  SEXP out_features_sexp;
+  SEXP out_scores_sexp;
+  SEXP out_shape_sexp;
+  SEXP r_out;
+
+  R_API_BEGIN();
+  char const *c_json_config = CHAR(asChar(json_config));
+  bst_ulong out_n_features;
+  char const **out_features;
+
+  bst_ulong out_dim;
+  bst_ulong const *out_shape;
+  float const *out_scores;
+
+  CHECK_CALL(XGBoosterFeatureScore(R_ExternalPtrAddr(handle), c_json_config,
+                                   &out_n_features, &out_features,
+                                   &out_dim, &out_shape, &out_scores));
+
+  out_shape_sexp = PROTECT(allocVector(INTSXP, out_dim));
+  size_t len = 1;
+  for (size_t i = 0; i < out_dim; ++i) {
+    INTEGER(out_shape_sexp)[i] = out_shape[i];
+    len *= out_shape[i];
+  }
+
+  out_scores_sexp = PROTECT(allocVector(REALSXP, len));
+#pragma omp parallel for
+  for (omp_ulong i = 0; i < len; ++i) {
+    REAL(out_scores_sexp)[i] = out_scores[i];
+  }
+
+  out_features_sexp = PROTECT(allocVector(STRSXP, out_n_features));
+  for (size_t i = 0; i < out_n_features; ++i) {
+    SET_STRING_ELT(out_features_sexp, i, mkChar(out_features[i]));
+  }
+
+  r_out = PROTECT(allocVector(VECSXP, 3));
+  SET_VECTOR_ELT(r_out, 0, out_features_sexp);
+  SET_VECTOR_ELT(r_out, 1, out_shape_sexp);
+  SET_VECTOR_ELT(r_out, 2, out_scores_sexp);
+
+  R_API_END();
+  UNPROTECT(4);
+
+  return r_out;
+}
diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h
index dcf5327dd..467aa54a3 100644
--- a/R-package/src/xgboost_R.h
+++ b/R-package/src/xgboost_R.h
@@ -275,4 +275,12 @@ XGB_DLL SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val);
  */
 XGB_DLL SEXP XGBoosterGetAttrNames_R(SEXP handle);
 
+/*!
+ * \brief Get feature scores from the model.
+ * \param handle handle to the booster
+ * \param json_config See `XGBoosterFeatureScore` in xgboost c_api.h
+ * \return A vector with the first element as feature names, the second element as the
+ *         shape of the feature scores and the third element as the feature scores.
+ */
+XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config);
+
 #endif  // XGBOOST_WRAPPER_R_H_ // NOLINT(*)
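The R wrapper above marshals the three out-parameters of `XGBoosterFeatureScore` into a list of names, shape, and scores. For reference, the same call can be sketched from Python with ctypes. This is a minimal illustration only: it assumes a build that includes this patch, leans on xgboost's private `_LIB` handle rather than public API (the supported route is `Booster.get_score`), and trains on synthetic data.

```python
# Minimal ctypes sketch of the XGBoosterFeatureScore call that the R wrapper
# above performs. Uses xgboost's private `_LIB` handle, so it is purely
# illustrative, not supported usage.
import ctypes
import json

import numpy as np
import xgboost as xgb
from xgboost.core import _LIB, _check_call, c_bst_ulong, from_cstr_to_pystr

X, y = np.random.randn(100, 4), np.random.randint(0, 2, 100)
booster = xgb.train({"booster": "gblinear"}, xgb.DMatrix(X, label=y), 4)

config = json.dumps(
    {"importance_type": "weight", "feature_map": None, "feature_names": None}
).encode()
n_features = c_bst_ulong()
features = ctypes.POINTER(ctypes.c_char_p)()
out_dim = c_bst_ulong()
shape = ctypes.POINTER(c_bst_ulong)()
scores = ctypes.POINTER(ctypes.c_float)()
_check_call(_LIB.XGBoosterFeatureScore(
    booster.handle, ctypes.c_char_p(config),
    ctypes.byref(n_features), ctypes.byref(features),
    ctypes.byref(out_dim), ctypes.byref(shape), ctypes.byref(scores),
))
# The same three results the R function assembles: names, shape, scores.
names = from_cstr_to_pystr(features, n_features)
dims = [shape[i] for i in range(out_dim.value)]
values = np.ctypeslib.as_array(scores, shape=(int(np.prod(dims)),))
print(names, dims, values)
```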
diff --git a/demo/guide-python/basic_walkthrough.py b/demo/guide-python/basic_walkthrough.py
index a76def962..c977a4f48 100644
--- a/demo/guide-python/basic_walkthrough.py
+++ b/demo/guide-python/basic_walkthrough.py
@@ -11,8 +11,8 @@ DEMO_DIR = os.path.join(XGBOOST_ROOT_DIR, 'demo')
 # simple example
 # load file from text file, also binary buffer generated by xgboost
-dtrain = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.train'))
-dtest = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.test'))
+dtrain = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.train?indexing_mode=1'))
+dtest = xgb.DMatrix(os.path.join(DEMO_DIR, 'data', 'agaricus.txt.test?indexing_mode=1'))
 
 # specify parameters via map, definition are same as c++ version
 param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index b0fc28825..ac66b53d4 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -1195,10 +1195,13 @@ XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
                                        const char ***out_features);
 
 /*!
- * \brief Calculate feature scores for tree models.
+ * \brief Calculate feature scores for tree models. When used on a linear model, only
+ *        the `weight` importance type is defined, and the output scores form a row-major
+ *        matrix with shape [n_features, n_classes] for multi-class models. For tree
+ *        models, `out_n_features` always equals the number of output scores, and all of
+ *        the importance types below are supported.
  *
- * \param handle An instance of Booster
- * \param json_config Parameters for computing scores. Accepted JSON keys are:
+ * \param handle      An instance of Booster
+ * \param json_config Parameters for computing scores.  Accepted JSON keys are:
  * - importance_type: A JSON string with following possible values:
  *   * 'weight': the number of times a feature is used to split the data across all trees.
  *   * 'gain': the average gain across all splits the feature is used in.
@@ -1206,15 +1209,20 @@ XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
  *   * 'total_gain': the total gain across all splits the feature is used in.
  *   * 'total_cover': the total coverage across all splits the feature is used in.
  * - feature_map: An optional JSON string with URI or path to the feature map file.
+ * - feature_names: An optional JSON array with string names for each feature.
  *
- * \param out_length Length of output arrays.
- * \param out_features An array of string as feature names, ordered the same as output scores.
- * \param out_scores An array of floating point as feature scores.
+ * \param out_n_features Length of output feature names.
+ * \param out_features   An array of strings as feature names, ordered the same as the output scores.
+ * \param out_dim        Dimension of output feature scores.
+ * \param out_shape      Shape of output feature scores, with length `out_dim`.
+ * \param out_scores     An array of floating point values as feature scores, laid out according to `out_shape`.
  *
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
-                                  bst_ulong *out_length,
-                                  const char ***out_features,
-                                  float **out_scores);
+                                  bst_ulong *out_n_features,
+                                  char const ***out_features,
+                                  bst_ulong *out_dim,
+                                  bst_ulong const **out_shape,
+                                  float const **out_scores);
 #endif  // XGBOOST_C_API_H_
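To make the documented semantics concrete, here is a small usage sketch through the Python binding updated later in this patch. The data is synthetic, so the scores themselves are arbitrary.

```python
# Usage sketch of the importance types documented above, via the Python
# binding. Synthetic data; the actual score values are meaningless.
import numpy as np
import xgboost as xgb

X = np.random.randn(256, 8)
y = np.random.randint(0, 3, size=256)
dtrain = xgb.DMatrix(X, label=y)

# Tree booster: all five importance types are accepted, scores are scalar.
tree = xgb.train({"objective": "multi:softprob", "num_class": 3}, dtrain, 4)
for imp in ("weight", "gain", "cover", "total_gain", "total_cover"):
    print(imp, tree.get_score(importance_type=imp))

# Linear booster: only "weight" is defined; with num_class = 3 each feature
# maps to one row of the [n_features, n_classes] matrix described above.
linear = xgb.train(
    {"booster": "gblinear", "objective": "multi:softprob", "num_class": 3},
    dtrain, 4,
)
print(linear.get_score(importance_type="weight"))
```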
diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index fde861f13..580cb52a5 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -184,9 +184,7 @@ class GradientBooster : public Model, public Configurable {
   virtual void FeatureScore(std::string const &importance_type,
                             std::vector<bst_feature_t> *features,
-                            std::vector<float> *scores) const {
-    LOG(FATAL) << "`feature_score` is not implemented for current booster.";
-  }
+                            std::vector<float> *scores) const = 0;
   /*!
    * \brief Whether the current booster uses GPU.
    */
diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index dbaf4b800..5bd6f913a 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include <vector>
 
 namespace xgboost {
 /*!
@@ -59,6 +60,13 @@ template <typename T> class MatrixView {
     strides_[0] = shape[1];
     strides_[1] = 1;
   }
+  MatrixView(std::vector<T> *vec, std::array<size_t, 2> shape)
+      : device_{GenericParameter::kCpuId}, values_{*vec} {
+    CHECK_EQ(vec->size(), shape[0] * shape[1]);
+    std::copy(shape.cbegin(), shape.cend(), shape_);
+    strides_[0] = shape[1];
+    strides_[1] = 1;
+  }
   MatrixView(HostDeviceVector<std::remove_const_t<T>> const *vec,
              std::array<size_t, 2> shape, int32_t device)
       : device_{device}, values_{InferValues(vec, device)} {
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
index 4df68c947..0b0e4cb0d 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
@@ -1,10 +1,10 @@
 /*
- Copyright (c) 2014 by Contributors
+ Copyright (c) 2014-2021 by Contributors
 
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -32,6 +32,9 @@ import org.junit.Test; * @author hzx */ public class BoosterImplTest { + private String train_uri = "../../demo/data/agaricus.txt.train?indexing_mode=1"; + private String test_uri = "../../demo/data/agaricus.txt.test?indexing_mode=1"; + public static class EvalError implements IEvaluation { @Override public String getMetric() { @@ -87,8 +90,8 @@ public class BoosterImplTest { @Test public void testBoosterBasic() throws XGBoostError, IOException { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); @@ -103,8 +106,8 @@ public class BoosterImplTest { @Test public void saveLoadModelWithPath() throws XGBoostError, IOException { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); IEvaluation eval = new EvalError(); Booster booster = trainBooster(trainMat, testMat); @@ -121,8 +124,8 @@ public class BoosterImplTest { @Test public void saveLoadModelWithStream() throws XGBoostError, IOException { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); @@ -310,8 +313,8 @@ public class BoosterImplTest { @Test public void testBoosterEarlyStop() throws XGBoostError, IOException { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Map paramMap = new HashMap() { { put("max_depth", 3); @@ -363,8 +366,8 @@ public class BoosterImplTest { @Test public void testQuantileHistoDepthWise() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Map paramMap = new HashMap() { { put("max_depth", 3); @@ -383,8 +386,8 @@ public class BoosterImplTest { @Test public void testQuantileHistoLossGuide() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Map paramMap = new HashMap() { { put("max_depth", 3); @@ -404,8 +407,8 @@ public class BoosterImplTest { @Test public void testQuantileHistoLossGuideMaxBin() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Map paramMap = new HashMap() { { put("max_depth", 3); @@ -425,8 +428,8 @@ public class BoosterImplTest { @Test public void testDumpModelJson() 
throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); String[] dump = booster.getModelDump("", false, "json"); @@ -441,8 +444,8 @@ public class BoosterImplTest { @Test public void testGetFeatureScore() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); String[] featureNames = new String[126]; @@ -453,8 +456,8 @@ public class BoosterImplTest { @Test public void testGetFeatureImportanceGain() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); String[] featureNames = new String[126]; @@ -465,8 +468,8 @@ public class BoosterImplTest { @Test public void testGetFeatureImportanceTotalGain() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); String[] featureNames = new String[126]; @@ -477,8 +480,8 @@ public class BoosterImplTest { @Test public void testGetFeatureImportanceCover() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); String[] featureNames = new String[126]; @@ -489,8 +492,8 @@ public class BoosterImplTest { @Test public void testGetFeatureImportanceTotalCover() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); String[] featureNames = new String[126]; @@ -501,7 +504,7 @@ public class BoosterImplTest { @Test public void testQuantileHistoDepthwiseMaxDepth() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix trainMat = new DMatrix(this.train_uri); Map paramMap = new HashMap() { { put("max_depth", 3); @@ -519,8 +522,8 @@ public class BoosterImplTest { @Test public void testQuantileHistoDepthwiseMaxDepthMaxBin() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Map paramMap = new HashMap() { { put("max_depth", 3); @@ -545,7 +548,7 @@ public class BoosterImplTest { @Test public void testCV() throws XGBoostError { //load train mat - DMatrix trainMat = new 
DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix trainMat = new DMatrix(this.train_uri); //set params Map param = new HashMap() { @@ -573,8 +576,8 @@ public class BoosterImplTest { */ @Test public void testTrainFromExistingModel() throws XGBoostError, IOException { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); IEvaluation eval = new EvalError(); Map paramMap = new HashMap() { @@ -624,8 +627,8 @@ public class BoosterImplTest { */ @Test public void testSetAndGetAttrs() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); booster.setAttr("testKey1", "testValue1"); @@ -654,10 +657,10 @@ public class BoosterImplTest { */ @Test public void testGetNumFeature() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + DMatrix trainMat = new DMatrix(this.train_uri); + DMatrix testMat = new DMatrix(this.test_uri); Booster booster = trainBooster(trainMat, testMat); - TestCase.assertEquals(booster.getNumFeature(), 127); + TestCase.assertEquals(booster.getNumFeature(), 126); } } diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 3f1806924..f4fe1b396 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -2132,47 +2132,18 @@ class Booster(object): fmap = os.fspath(os.path.expanduser(fmap)) length = c_bst_ulong() sarr = ctypes.POINTER(ctypes.c_char_p)() - if self.feature_names is not None and fmap == '': - flen = len(self.feature_names) - - fname = from_pystr_to_cstr(self.feature_names) - - if self.feature_types is None: - # use quantitative as default - # {'q': quantitative, 'i': indicator} - ftype = from_pystr_to_cstr(['q'] * flen) - else: - ftype = from_pystr_to_cstr(self.feature_types) - _check_call(_LIB.XGBoosterDumpModelExWithFeatures( - self.handle, - ctypes.c_int(flen), - fname, - ftype, - ctypes.c_int(with_stats), - c_str(dump_format), - ctypes.byref(length), - ctypes.byref(sarr))) - else: - if fmap != '' and not os.path.exists(fmap): - raise ValueError("No such file: {0}".format(fmap)) - _check_call(_LIB.XGBoosterDumpModelEx(self.handle, - c_str(fmap), - ctypes.c_int(with_stats), - c_str(dump_format), - ctypes.byref(length), - ctypes.byref(sarr))) + _check_call(_LIB.XGBoosterDumpModelEx(self.handle, + c_str(fmap), + ctypes.c_int(with_stats), + c_str(dump_format), + ctypes.byref(length), + ctypes.byref(sarr))) res = from_cstr_to_pystr(sarr, length) return res def get_fscore(self, fmap=''): """Get feature importance of each feature. - .. note:: Feature importance is defined only for tree boosters - - Feature importance is only defined when the decision tree model is chosen as base - learner (`booster=gbtree`). It is not defined for other base learner types, such - as linear learners (`booster=gblinear`). - .. note:: Zero-importance features will not be included Keep in mind that this function does not include zero-importance feature, i.e. 
@@ -2190,7 +2161,7 @@ class Booster(object): self, fmap: os.PathLike = '', importance_type: str = 'weight' ) -> Dict[str, float]: """Get feature importance of each feature. - Importance type can be defined as: + For tree model Importance type can be defined as: * 'weight': the number of times a feature is used to split the data across all trees. * 'gain': the average gain across all splits the feature is used in. @@ -2198,11 +2169,15 @@ class Booster(object): * 'total_gain': the total gain across all splits the feature is used in. * 'total_cover': the total coverage across all splits the feature is used in. - .. note:: Feature importance is defined only for tree boosters + .. note:: - Feature importance is only defined when the decision tree model is chosen as - base learner (`booster=gbtree` or `booster=dart`). It is not defined for other - base learner types, such as linear learners (`booster=gblinear`). + For linear model, only "weight" is defined and it's the normalized coefficients + without bias. + + .. note:: Zero-importance features will not be included + + Keep in mind that this function does not include zero-importance feature, i.e. + those features that have not been used in any split conditions. Parameters ---------- @@ -2213,7 +2188,9 @@ class Booster(object): Returns ------- - A map between feature names and their scores. + A map between feature names and their scores. When `gblinear` is used for + multi-class classification the scores for each feature is a list with length + `n_classes`, otherwise they're scalars. """ fmap = os.fspath(os.path.expanduser(fmap)) args = from_pystr_to_cstr( @@ -2221,21 +2198,31 @@ class Booster(object): ) features = ctypes.POINTER(ctypes.c_char_p)() scores = ctypes.POINTER(ctypes.c_float)() - length = c_bst_ulong() + n_out_features = c_bst_ulong() + out_dim = c_bst_ulong() + shape = ctypes.POINTER(c_bst_ulong)() + _check_call( _LIB.XGBoosterFeatureScore( self.handle, args, - ctypes.byref(length), + ctypes.byref(n_out_features), ctypes.byref(features), - ctypes.byref(scores) + ctypes.byref(out_dim), + ctypes.byref(shape), + ctypes.byref(scores), ) ) - features_arr = from_cstr_to_pystr(features, length) - scores_arr = ctypes2numpy(scores, length.value, np.float32) + features_arr = from_cstr_to_pystr(features, n_out_features) + scores_arr = _prediction_output(shape, out_dim, scores, False) + results = {} - for feat, score in zip(features_arr, scores_arr): - results[feat] = float(score) + if len(scores_arr.shape) > 1 and scores_arr.shape[1] > 1: + for feat, score in zip(features_arr, scores_arr): + results[feat] = [float(s) for s in score] + else: + for feat, score in zip(features_arr, scores_arr): + results[feat] = float(score) return results def trees_to_dataframe(self, fmap=''): diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 76e5d7e9c..82fbb3eef 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -156,9 +156,14 @@ __model_doc = f''' [2, 3, 4]], where each inner list is a group of indices of features that are allowed to interact with each other. See tutorial for more information - importance_type: string, default "gain" + importance_type: Optional[str] The feature importance type for the feature_importances\\_ property: - either "gain", "weight", "cover", "total_gain" or "total_cover". + + * For tree model, it's either "gain", "weight", "cover", "total_gain" or + "total_cover". 
+ * For linear model, only "weight" is defined and it's the normalized coefficients + without bias. + gpu_id : Optional[int] Device ordinal. validate_parameters : Optional[bool] @@ -382,7 +387,7 @@ class XGBModel(XGBModelBase): num_parallel_tree: Optional[int] = None, monotone_constraints: Optional[Union[Dict[str, int], str]] = None, interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None, - importance_type: str = "gain", + importance_type: Optional[str] = None, gpu_id: Optional[int] = None, validate_parameters: Optional[bool] = None, predictor: Optional[str] = None, @@ -991,29 +996,26 @@ class XGBModel(XGBModelBase): @property def feature_importances_(self) -> np.ndarray: """ - Feature importances property - - .. note:: Feature importance is defined only for tree boosters - - Feature importance is only defined when the decision tree model is chosen as base - learner (`booster=gbtree`). It is not defined for other base learner types, such - as linear learners (`booster=gblinear`). + Feature importances property, return depends on `importance_type` parameter. Returns ------- - feature_importances_ : array of shape ``[n_features]`` + feature_importances_ : array of shape ``[n_features]`` except for multi-class + linear model, which returns an array with shape `(n_features, n_classes)` """ - if self.get_params()['booster'] not in {'gbtree', 'dart'}: - raise AttributeError( - 'Feature importance is not defined for Booster type {}' - .format(self.booster)) b: Booster = self.get_booster() - score = b.get_score(importance_type=self.importance_type) + + def dft() -> str: + return "weight" if self.booster == "gblinear" else "gain" + score = b.get_score( + importance_type=self.importance_type if self.importance_type else dft() + ) if b.feature_names is None: feature_names = ["f{0}".format(i) for i in range(self.n_features_in_)] else: feature_names = b.feature_names + # gblinear returns all features so the `get` in next line is only for gbtree. all_features = [score.get(f, 0.) 
for f in feature_names]
         all_features_arr = np.array(all_features, dtype=np.float32)
         total = all_features_arr.sum()
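A short sketch of what the revised property yields; the default importance type is now chosen per booster ("weight" for gblinear, "gain" otherwise), and the data here is synthetic, so only the shapes matter.

```python
# Sketch of the revised `feature_importances_` behaviour after this patch:
# gblinear defaults to "weight" and returns one row per feature with one
# column per class; gbtree keeps its "gain" default and stays 1-D.
import numpy as np
import xgboost as xgb

X = np.random.randn(200, 5)
y = np.random.randint(0, 3, size=200)

linear = xgb.XGBClassifier(booster="gblinear", n_estimators=4).fit(X, y)
print(linear.feature_importances_.shape)  # (5, 3): [n_features, n_classes]

tree = xgb.XGBClassifier(n_estimators=4).fit(X, y)
print(tree.feature_importances_.shape)  # (5,)
```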
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 5354e78b1..c8be90caf 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -927,14 +927,17 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
   API_END();
 }
 
-inline void XGBoostDumpModelImpl(BoosterHandle handle, const FeatureMap &fmap,
+inline void XGBoostDumpModelImpl(BoosterHandle handle, FeatureMap *fmap,
                                  int with_stats, const char *format,
                                  xgboost::bst_ulong *len,
                                  const char ***out_models) {
   auto *bst = static_cast<Learner *>(handle);
+  bst->Configure();
+  GenerateFeatureMap(bst, {}, bst->GetNumFeature(), fmap);
+
   std::vector<std::string> &str_vecs = bst->GetThreadLocal().ret_vec_str;
   std::vector<const char *> &charp_vecs = bst->GetThreadLocal().ret_vec_charp;
-  str_vecs = bst->DumpModel(fmap, with_stats != 0, format);
+  str_vecs = bst->DumpModel(*fmap, with_stats != 0, format);
   charp_vecs.resize(str_vecs.size());
   for (size_t i = 0; i < str_vecs.size(); ++i) {
     charp_vecs[i] = str_vecs[i].c_str();
@@ -962,14 +965,9 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
                                 const char ***out_models) {
   API_BEGIN();
   CHECK_HANDLE();
-  FeatureMap featmap;
-  if (strlen(fmap) != 0) {
-    std::unique_ptr<dmlc::Stream> fs(
-        dmlc::Stream::Create(fmap, "r"));
-    dmlc::istream is(fs.get());
-    featmap.LoadText(is);
-  }
-  XGBoostDumpModelImpl(handle, featmap, with_stats, format, len, out_models);
+  std::string uri{fmap};
+  FeatureMap featmap = LoadFeatureMap(uri);
+  XGBoostDumpModelImpl(handle, &featmap, with_stats, format, len, out_models);
   API_END();
 }
 
@@ -980,8 +978,8 @@ XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
                                            int with_stats,
                                            xgboost::bst_ulong* len,
                                            const char*** out_models) {
-  return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats,
-                                          "text", len, out_models);
+  return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype,
+                                          with_stats, "text", len, out_models);
 }
 
 XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
@@ -998,7 +996,7 @@ XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
   for (int i = 0; i < fnum; ++i) {
     featmap.PushBack(i, fname[i], ftype[i]);
   }
-  XGBoostDumpModelImpl(handle, featmap, with_stats, format, len, out_models);
+  XGBoostDumpModelImpl(handle, &featmap, with_stats, format, len, out_models);
   API_END();
 }
 
@@ -1098,11 +1096,12 @@ XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
   API_END();
 }
 
-XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle,
-                                  const char *json_config,
-                                  xgboost::bst_ulong* out_length,
-                                  const char ***out_features,
-                                  float **out_scores) {
+XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
+                                  xgboost::bst_ulong *out_n_features,
+                                  char const ***out_features,
+                                  bst_ulong *out_dim,
+                                  bst_ulong const **out_shape,
+                                  float const **out_scores) {
   API_BEGIN();
   CHECK_HANDLE();
   auto *learner = static_cast<Learner *>(handle);
@@ -1113,14 +1112,17 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle,
     feature_map_uri = get<String const>(config["feature_map"]);
   }
   FeatureMap feature_map = LoadFeatureMap(feature_map_uri);
+  std::vector<Json> custom_feature_names;
+  if (!IsA<Null>(config["feature_names"])) {
+    custom_feature_names = get<Array const>(config["feature_names"]);
+  }
 
   auto &scores = learner->GetThreadLocal().ret_vec_float;
   std::vector<bst_feature_t> features;
   learner->CalcFeatureScore(importance, &features, &scores);
 
   auto n_features = learner->GetNumFeature();
-  GenerateFeatureMap(learner, n_features, &feature_map);
-  CHECK_LE(features.size(), n_features);
+  GenerateFeatureMap(learner, custom_feature_names, n_features, &feature_map);
 
   auto &feature_names = learner->GetThreadLocal().ret_vec_str;
   feature_names.resize(features.size());
@@ -1131,10 +1133,24 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle,
     feature_names[i] = feature_map.Name(features[i]);
     feature_names_c[i] = feature_names[i].data();
   }
+  *out_n_features = feature_names.size();
 
-  CHECK_EQ(scores.size(), features.size());
-  CHECK_EQ(scores.size(), feature_names.size());
-  *out_length = scores.size();
+  CHECK_LE(features.size(), scores.size());
+  auto &shape = learner->GetThreadLocal().prediction_shape;
+  if (scores.size() > features.size()) {
+    // Multi-class linear model.
+    CHECK_EQ(scores.size() % features.size(), 0ul);
+    auto n_classes = scores.size() / features.size();
+    *out_dim = 2;
+    shape = {n_features, n_classes};
+  } else {
+    CHECK_EQ(features.size(), scores.size());
+    *out_dim = 1;
+    shape.resize(1);
+    shape.front() = scores.size();
+  }
+
+  *out_shape = dmlc::BeginPtr(shape);
   *out_scores = scores.data();
   *out_features = dmlc::BeginPtr(feature_names_c);
   API_END();
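The branch above fixes the output contract: scores always come back as a flat, row-major buffer plus an explicit shape. A sketch of the decoding a caller performs follows; the helper name `decode_feature_score` is hypothetical, not library API, and it mirrors what the Python binding does in `Booster.get_score`.

```python
# Sketch of how a caller decodes XGBoosterFeatureScore's outputs: a flat,
# row-major score buffer plus a shape. For a multi-class gblinear model the
# shape is (n_features, n_classes); for gbtree it is (n_scores,).
# `decode_feature_score` is a hypothetical helper, not part of xgboost.
import numpy as np

def decode_feature_score(features, shape, flat_scores):
    """`features`, `shape`, `flat_scores` mirror the C API out-parameters."""
    scores = np.asarray(flat_scores, dtype=np.float32).reshape(shape)
    if scores.ndim > 1 and scores.shape[1] > 1:
        # One row of per-class coefficients per feature.
        return {f: list(map(float, row)) for f, row in zip(features, scores)}
    return {f: float(s) for f, s in zip(features, scores.reshape(-1))}

# e.g. two features, three classes, row-major buffer:
print(decode_feature_score(["f0", "f1"], (2, 3), [.1, .2, .3, .4, .5, .6]))
# {'f0': [0.1, 0.2, 0.3], 'f1': [0.4, 0.5, 0.6]}
```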
diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h
index 7c1538cb1..b044c6879 100644
--- a/src/c_api/c_api_utils.h
+++ b/src/c_api/c_api_utils.h
@@ -194,8 +194,8 @@ inline FeatureMap LoadFeatureMap(std::string const& uri) {
   return feat;
 }
 
-// FIXME(jiamingy): Use this for model dump.
 inline void GenerateFeatureMap(Learner const *learner,
+                               std::vector<Json> const &custom_feature_names,
                                size_t n_features, FeatureMap *out_feature_map) {
   auto &feature_map = *out_feature_map;
   auto maybe = [&](std::vector<std::string> const &values, size_t i,
@@ -205,15 +205,31 @@ inline void GenerateFeatureMap(Learner const *learner,
   if (feature_map.Size() == 0) {
     // Use the feature names and types from booster.
     std::vector<std::string> feature_names;
-    learner->GetFeatureNames(&feature_names);
+    // Priority:
+    // 1. feature map.
+    // 2. customized feature names.
+    // 3. from booster.
+    // 4. default feature names.
+    if (!custom_feature_names.empty()) {
+      CHECK_EQ(custom_feature_names.size(), n_features)
+          << "Incorrect number of feature names.";
+      feature_names.resize(custom_feature_names.size());
+      std::transform(custom_feature_names.begin(), custom_feature_names.end(),
+                     feature_names.begin(),
+                     [](Json const &name) { return get<String const>(name); });
+    } else {
+      learner->GetFeatureNames(&feature_names);
+    }
     if (!feature_names.empty()) {
       CHECK_EQ(feature_names.size(), n_features) << "Incorrect number of feature names.";
     }
+
     std::vector<std::string> feature_types;
     learner->GetFeatureTypes(&feature_types);
     if (!feature_types.empty()) {
       CHECK_EQ(feature_types.size(), n_features) << "Incorrect number of feature types.";
     }
+
     for (size_t i = 0; i < n_features; ++i) {
       feature_map.PushBack(
           i,
diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index 2cf2edf7f..a2dc1b67e 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include <numeric>
 
 #include "xgboost/gbm.h"
 #include "xgboost/json.h"
@@ -19,6 +20,7 @@
 #include "xgboost/linear_updater.h"
 #include "xgboost/logging.h"
 #include "xgboost/learner.h"
+#include "xgboost/linalg.h"
 
 #include "gblinear_model.h"
 #include "../common/timer.h"
@@ -219,6 +221,26 @@ class GBLinear : public GradientBooster {
     return model_.DumpModel(fmap, with_stats, format);
   }
 
+  void FeatureScore(std::string const &importance_type,
+                    std::vector<bst_feature_t> *out_features,
+                    std::vector<float> *out_scores) const override {
+    CHECK(!model_.weight.empty()) << "Model is not initialized";
+    CHECK_EQ(importance_type, "weight")
+        << "gblinear only has `weight` defined for feature importance.";
+    out_features->resize(this->learner_model_param_->num_feature, 0);
+    std::iota(out_features->begin(), out_features->end(), 0);
+    // Don't include the bias term in the feature importance scores.
+    // The bias is the last weight.
+    out_scores->resize(model_.weight.size() - learner_model_param_->num_output_group, 0);
+    auto n_groups = learner_model_param_->num_output_group;
+    MatrixView<float> scores{out_scores, {learner_model_param_->num_feature, n_groups}};
+    for (size_t i = 0; i < learner_model_param_->num_feature; ++i) {
+      for (bst_group_t g = 0; g < n_groups; ++g) {
+        scores(i, g) = model_[i][g];
+      }
+    }
+  }
+
   bool UseGPU() const override {
     if (param_.updater == "gpu_coord_descent") {
       return true;
diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h
index 872e22fd4..958ce00f8 100644
--- a/src/gbm/gbtree.h
+++ b/src/gbm/gbtree.h
@@ -325,16 +325,19 @@ class GBTree : public GradientBooster {
       add_score([&](auto const &p_tree, bst_node_t, bst_feature_t split) {
         gain_map[split] = split_counts[split];
       });
-    }
-    if (importance_type == "gain" || importance_type == "total_gain") {
+    } else if (importance_type == "gain" || importance_type == "total_gain") {
       add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
         gain_map[split] += p_tree->Stat(nidx).loss_chg;
       });
-    }
-    if (importance_type == "cover" || importance_type == "total_cover") {
+    } else if (importance_type == "cover" || importance_type == "total_cover") {
       add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
         gain_map[split] += p_tree->Stat(nidx).sum_hess;
       });
+    } else {
+      LOG(FATAL)
+          << "Unknown feature importance type, expected one of: "
+          << R"({"weight", "total_gain", "total_cover", "gain", "cover"}, got: )"
+          << importance_type;
     }
     if (importance_type == "gain" || importance_type == "cover") {
       for (size_t i = 0; i < gain_map.size(); ++i) {
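`GBLinear::FeatureScore` above lays the coefficients out row-major, one row per feature, and drops the trailing bias row; the new sklearn test below checks exactly this equivalence against the saved JSON weights. A small numpy sketch of that layout, with stand-in weights rather than a real model:

```python
# Sketch of the gblinear weight layout that FeatureScore exposes: the stored
# weights are (n_features + 1) x n_groups, row-major, with the bias as the
# last row. Feature scores are those weights with the bias row dropped, and
# sklearn's `feature_importances_` then normalises by the grand total.
import numpy as np

n_features, n_groups = 4, 3
rng = np.random.default_rng(0)
weights = rng.normal(size=(n_features + 1, n_groups))  # stand-in model weights

scores = weights[:-1, :]             # what FeatureScore reports
flat = scores.reshape(-1)            # the flat out_scores buffer
importances = scores / scores.sum()  # sklearn's normalised property
assert flat.size == n_features * n_groups
```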
diff --git a/src/learner.cc b/src/learner.cc
index a3086aa72..15adf95bf 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -1197,23 +1197,6 @@ class LearnerImpl : public LearnerIO {
                         std::vector<bst_feature_t> *features,
                         std::vector<float> *scores) override {
     this->Configure();
-    std::vector<std::string> allowed_importance_type = {
-        "weight", "total_gain", "total_cover", "gain", "cover"
-    };
-    if (std::find(allowed_importance_type.begin(),
-                  allowed_importance_type.end(),
-                  importance_type) == allowed_importance_type.end()) {
-      std::stringstream ss;
-      ss << "importance_type mismatch, got: " << importance_type
-         << "`, expected one of ";
-      for (size_t i = 0; i < allowed_importance_type.size(); ++i) {
-        ss << "`" << allowed_importance_type[i] << "`";
-        if (i != allowed_importance_type.size() - 1) {
-          ss << ", ";
-        }
-      }
-      LOG(FATAL) << ss.str();
-    }
     gbm_->FeatureScore(importance_type, features, scores);
   }
 
diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py
index bddb4458b..f95043fed 100644
--- a/tests/python/test_basic.py
+++ b/tests/python/test_basic.py
@@ -154,6 +154,9 @@ class TestBasic:
         dump4j = json.loads(dump4[0])
         assert 'gain' in dump4j, "Expected 'gain' to be dumped in JSON."
 
+        with pytest.raises(ValueError):
+            bst.get_dump(fmap="foo")
+
     def test_feature_score(self):
         rng = np.random.RandomState(0)
         data = rng.randn(100, 2)
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 12297ceec..d44d0e3af 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -211,6 +211,7 @@ def test_feature_importances_weight():
     digits = load_digits(n_class=2)
     y = digits['target']
     X = digits['data']
+
     xgb_model = xgb.XGBClassifier(random_state=0,
                                   tree_method="exact",
                                   learning_rate=0.1,
@@ -241,6 +242,33 @@ def test_feature_importances_weight():
                            importance_type="weight").fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
+    with pytest.raises(ValueError):
+        xgb_model.set_params(importance_type="foo")
+        xgb_model.feature_importances_
+
+    X, y = load_digits(n_class=3, return_X_y=True)
+
+    cls = xgb.XGBClassifier(booster="gblinear", n_estimators=4)
+    cls.fit(X, y)
+    assert cls.feature_importances_.shape[0] == X.shape[1]
+    assert cls.feature_importances_.shape[1] == 3
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = os.path.join(tmpdir, "model.json")
+        cls.save_model(path)
+        with open(path, "r") as fd:
+            model = json.load(fd)
+    weights = np.array(
+        model["learner"]["gradient_booster"]["model"]["weights"]
+    ).reshape((cls.n_features_in_ + 1, 3))
+    weights = weights[:-1, ...]
+    np.testing.assert_allclose(
+        weights / weights.sum(), cls.feature_importances_, rtol=1e-6
+    )
+
+    with pytest.raises(ValueError):
+        cls.set_params(importance_type="cover")
+        cls.feature_importances_
+
 
 @pytest.mark.skipif(**tm.no_pandas())
 def test_feature_importances_gain():