diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 851a1133e..f6913ffc7 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -16,9 +16,12 @@ package ml.dmlc.xgboost4j.java; import java.io.*; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import com.esotericsoftware.kryo.Kryo; import com.esotericsoftware.kryo.KryoSerializable; @@ -395,6 +398,25 @@ public class Booster implements Serializable, KryoSerializable { return modelInfos[0]; } + /** + * Supported feature importance types + * + * WEIGHT = Number of nodes that a feature was used to determine a split + * GAIN = Average information gain per split for a feature + * COVER = Average cover per split for a feature + * TOTAL_GAIN = Total information gain over all splits of a feature + * TOTAL_COVER = Total cover over all splits of a feature + */ + public static class FeatureImportanceType { + public static final String WEIGHT = "weight"; + public static final String GAIN = "gain"; + public static final String COVER = "cover"; + public static final String TOTAL_GAIN = "total_gain"; + public static final String TOTAL_COVER = "total_cover"; + public static final Set ACCEPTED_TYPES = new HashSet<>( + Arrays.asList(WEIGHT, GAIN, COVER, TOTAL_GAIN, TOTAL_COVER)); + } + /** * Get importance of each feature with specified feature names. 
* @@ -403,6 +425,28 @@ public class Booster implements Serializable, KryoSerializable { */ public Map getFeatureScore(String[] featureNames) throws XGBoostError { String[] modelInfos = getModelDump(featureNames, false); + return getFeatureWeightsFromModel(modelInfos); + } + + /** + * Get importance of each feature + * + * @return featureScoreMap key: feature index, value: feature importance score, can be nill + * @throws XGBoostError native error + */ + public Map getFeatureScore(String featureMap) throws XGBoostError { + String[] modelInfos = getModelDump(featureMap, false); + return getFeatureWeightsFromModel(modelInfos); + } + + /** + * Get the importance of each feature based purely on weights (number of splits) + * + * @return featureScoreMap key: feature index, + * value: feature importance score based on weight + * @throws XGBoostError native error + */ + private Map getFeatureWeightsFromModel(String[] modelInfos) throws XGBoostError { Map featureScore = new HashMap<>(); for (String tree : modelInfos) { for (String node : tree.split("\n")) { @@ -423,30 +467,91 @@ public class Booster implements Serializable, KryoSerializable { } /** - * Get importance of each feature + * Get the feature importances for gain or cover (average or total) * - * @return featureScoreMap key: feature index, value: feature importance score, can be nill + * @return featureImportanceMap key: feature index, + * values: feature importance score based on gain or cover * @throws XGBoostError native error */ - public Map getFeatureScore(String featureMap) throws XGBoostError { - String[] modelInfos = getModelDump(featureMap, false); - Map featureScore = new HashMap<>(); - for (String tree : modelInfos) { - for (String node : tree.split("\n")) { + public Map getScore( + String[] featureNames, String importanceType) throws XGBoostError { + String[] modelInfos = getModelDump(featureNames, true); + return getFeatureImportanceFromModel(modelInfos, importanceType); + } + + /** + * Get the 
feature importances for gain or cover (average or total), with feature names + * + * @return featureImportanceMap key: feature name, + * values: feature importance score based on gain or cover + * @throws XGBoostError native error + */ + public Map getScore( + String featureMap, String importanceType) throws XGBoostError { + String[] modelInfos = getModelDump(featureMap, true); + return getFeatureImportanceFromModel(modelInfos, importanceType); + } + + /** + * Get the importance of each feature based on information gain or cover + * + * @return featureImportanceMap key: feature index, value: feature importance score + * based on information gain or cover + * @throws XGBoostError native error + */ + private Map getFeatureImportanceFromModel( + String[] modelInfos, String importanceType) throws XGBoostError { + if (!FeatureImportanceType.ACCEPTED_TYPES.contains(importanceType)) { + throw new AssertionError(String.format("Importance type %s is not supported", + importanceType)); + } + Map importanceMap = new HashMap<>(); + Map weightMap = new HashMap<>(); + if (importanceType == FeatureImportanceType.WEIGHT) { + Map importanceWeights = getFeatureWeightsFromModel(modelInfos); + for (String feature: importanceWeights.keySet()) { + importanceMap.put(feature, new Double(importanceWeights.get(feature))); + } + return importanceMap; + } + /* Each split in the tree has this text form: + "0:[f28<-9.53674316e-07] yes=1,no=2,missing=1,gain=4000.53101,cover=1628.25" + So the line has to be split according to whether cover or gain is desired */ + String splitter = "gain="; + if (importanceType == FeatureImportanceType.COVER + || importanceType == FeatureImportanceType.TOTAL_COVER) { + splitter = "cover="; + } + for (String tree: modelInfos) { + for (String node: tree.split("\n")) { String[] array = node.split("\\["); if (array.length == 1) { continue; } - String fid = array[1].split("\\]")[0]; - fid = fid.split("<")[0]; - if (featureScore.containsKey(fid)) { - 
featureScore.put(fid, 1 + featureScore.get(fid)); + String[] fidWithImportance = array[1].split("\\]"); + // Extract gain or cover from string after closing bracket + Double importance = Double.parseDouble( + fidWithImportance[1].split(splitter)[1].split(",")[0] + ); + String fid = fidWithImportance[0].split("<")[0]; + if (importanceMap.containsKey(fid)) { + importanceMap.put(fid, importance + importanceMap.get(fid)); + weightMap.put(fid, 1d + weightMap.get(fid)); } else { - featureScore.put(fid, 1); + importanceMap.put(fid, importance); + weightMap.put(fid, 1d); } } } - return featureScore; + /* By default we calculate total gain and total cover. + Divide by the number of nodes per feature to get gain / cover */ + if (importanceType == FeatureImportanceType.COVER + || importanceType == FeatureImportanceType.GAIN) { + for (String fid: importanceMap.keySet()) { + importanceMap.put(fid, importanceMap.get(fid)/weightMap.get(fid)); + } + } + return importanceMap; } /** diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index 60b98f867..f86ab8e18 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -204,7 +204,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) /** - * Get importance of each feature + * Get importance of each feature based on weight only (number of splits) * * @return featureScoreMap key: feature index, value: feature importance score */ @@ -214,7 +214,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) } /** - * Get importance of each feature with specified feature names. + * Get importance of each feature based on weight only + * (number of splits), with specified feature names. 
/**
 * Get importance of each feature for the requested importance type.
 * Supported: ["weight", "gain", "cover", "total_gain", "total_cover"]
 *
 * @param featureMap     forwarded unchanged to the underlying Java booster
 * @param importanceType name of the importance type to compute
 * @return featureScoreMap key: feature index, value: feature importance score
 */
@throws(classOf[XGBoostError])
def getScore(featureMap: String, importanceType: String): Map[String, Double] = {
  val javaScores = booster.getScore(featureMap, importanceType).asScala
  // Unbox each java.lang.Double and materialise an immutable Scala map.
  javaScores.map { case (feature, score) => feature -> score.doubleValue }.toMap
}

/**
 * Get importance of each feature for the requested importance type,
 * with specified feature names.
 * Supported: ["weight", "gain", "cover", "total_gain", "total_cover"]
 *
 * @param featureNames   forwarded unchanged to the underlying Java booster
 * @param importanceType name of the importance type to compute
 * @return featureScoreMap key: feature name, value: feature importance score
 */
@throws(classOf[XGBoostError])
def getScore(featureNames: Array[String], importanceType: String): Map[String, Double] = {
  val javaScores = booster.getScore(featureNames, importanceType).asScala
  // Unbox each java.lang.Double and materialise an immutable Scala map.
  javaScores.map { case (feature, score) => feature -> score.doubleValue }.toMap
}
(String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); } + @Test + public void testGetFeatureImportanceGain() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + Booster booster = trainBooster(trainMat, testMat); + String[] featureNames = new String[126]; + for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; + Map scoreMap = booster.getScore(featureNames, "gain"); + for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); + } + + @Test + public void testGetFeatureImportanceTotalGain() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + Booster booster = trainBooster(trainMat, testMat); + String[] featureNames = new String[126]; + for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; + Map scoreMap = booster.getScore(featureNames, "total_gain"); + for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); + } + + @Test + public void testGetFeatureImportanceCover() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + Booster booster = trainBooster(trainMat, testMat); + String[] featureNames = new String[126]; + for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; + Map scoreMap = booster.getScore(featureNames, "cover"); + for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); + } + + @Test + public void testGetFeatureImportanceTotalCover() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + 
Booster booster = trainBooster(trainMat, testMat); + String[] featureNames = new String[126]; + for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; + Map scoreMap = booster.getScore(featureNames, "total_cover"); + for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); + } + @Test public void testFastHistoDepthwiseMaxDepth() throws XGBoostError { DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");