[jvm-packages] Updates to Java Booster to support other feature importance measures (#3801)

* Updates to Booster to support other feature importances

* Add returns for Java methods

* Pass Scala style checks

* Pass Java style checks

* Fix indents

* Use class instead of enum

* Return map string double

* A no longer broken build, thanks to mvn package local build

* Add a unit test to increase code coverage back

* Address code review on main code

* Add more unit tests for different feature importance scores

* Address more CR
This commit is contained in:
Shayak Banerjee 2019-01-02 04:13:14 -05:00 committed by Nan Zhu
parent 1f022929f4
commit 431c850c03
3 changed files with 195 additions and 16 deletions

View File

@ -16,9 +16,12 @@
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.KryoSerializable;
@ -395,6 +398,25 @@ public class Booster implements Serializable, KryoSerializable {
return modelInfos[0];
}
/**
* Supported feature importance types
*
* WEIGHT = Number of nodes that a feature was used to determine a split
* GAIN = Average information gain per split for a feature
* COVER = Average cover per split for a feature
* TOTAL_GAIN = Total information gain over all splits of a feature
* TOTAL_COVER = Total cover over all splits of a feature
*/
public static class FeatureImportanceType {
public static final String WEIGHT = "weight";
public static final String GAIN = "gain";
public static final String COVER = "cover";
public static final String TOTAL_GAIN = "total_gain";
public static final String TOTAL_COVER = "total_cover";
public static final Set<String> ACCEPTED_TYPES = new HashSet<>(
Arrays.asList(WEIGHT, GAIN, COVER, TOTAL_GAIN, TOTAL_COVER));
}
/**
* Get importance of each feature with specified feature names.
*
@ -403,6 +425,28 @@ public class Booster implements Serializable, KryoSerializable {
*/
public Map<String, Integer> getFeatureScore(String[] featureNames) throws XGBoostError {
String[] modelInfos = getModelDump(featureNames, false);
return getFeatureWeightsFromModel(modelInfos);
}
/**
* Get importance of each feature
*
* @return featureScoreMap key: feature index, value: feature importance score, can be nill
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
String[] modelInfos = getModelDump(featureMap, false);
return getFeatureWeightsFromModel(modelInfos);
}
/**
* Get the importance of each feature based purely on weights (number of splits)
*
* @return featureScoreMap key: feature index,
* value: feature importance score based on weight
* @throws XGBoostError native error
*/
private Map<String, Integer> getFeatureWeightsFromModel(String[] modelInfos) throws XGBoostError {
Map<String, Integer> featureScore = new HashMap<>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
@ -423,30 +467,91 @@ public class Booster implements Serializable, KryoSerializable {
}
/**
* Get importance of each feature
* Get the feature importances for gain or cover (average or total)
*
* @return featureScoreMap key: feature index, value: feature importance score, can be nill
* @return featureImportanceMap key: feature index,
* values: feature importance score based on gain or cover
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
String[] modelInfos = getModelDump(featureMap, false);
Map<String, Integer> featureScore = new HashMap<>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
public Map<String, Double> getScore(
String[] featureNames, String importanceType) throws XGBoostError {
String[] modelInfos = getModelDump(featureNames, true);
return getFeatureImportanceFromModel(modelInfos, importanceType);
}
/**
* Get the feature importances for gain or cover (average or total), with feature names
*
* @return featureImportanceMap key: feature name,
* values: feature importance score based on gain or cover
* @throws XGBoostError native error
*/
public Map<String, Double> getScore(
String featureMap, String importanceType) throws XGBoostError {
String[] modelInfos = getModelDump(featureMap, true);
return getFeatureImportanceFromModel(modelInfos, importanceType);
}
/**
* Get the importance of each feature based on information gain or cover
*
* @return featureImportanceMap key: feature index, value: feature importance score
* based on information gain or cover
* @throws XGBoostError native error
*/
private Map<String, Double> getFeatureImportanceFromModel(
String[] modelInfos, String importanceType) throws XGBoostError {
if (!FeatureImportanceType.ACCEPTED_TYPES.contains(importanceType)) {
throw new AssertionError(String.format("Importance type %s is not supported",
importanceType));
}
Map<String, Double> importanceMap = new HashMap<>();
Map<String, Double> weightMap = new HashMap<>();
if (importanceType == FeatureImportanceType.WEIGHT) {
Map<String, Integer> importanceWeights = getFeatureWeightsFromModel(modelInfos);
for (String feature: importanceWeights.keySet()) {
importanceMap.put(feature, new Double(importanceWeights.get(feature)));
}
return importanceMap;
}
/* Each split in the tree has this text form:
"0:[f28<-9.53674316e-07] yes=1,no=2,missing=1,gain=4000.53101,cover=1628.25"
So the line has to be split according to whether cover or gain is desired */
String splitter = "gain=";
if (importanceType == FeatureImportanceType.COVER
|| importanceType == FeatureImportanceType.TOTAL_COVER) {
splitter = "cover=";
}
for (String tree: modelInfos) {
for (String node: tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
String[] fidWithImportance = array[1].split("\\]");
// Extract gain or cover from string after closing bracket
Double importance = Double.parseDouble(
fidWithImportance[1].split(splitter)[1].split(",")[0]
);
String fid = fidWithImportance[0].split("<")[0];
if (importanceMap.containsKey(fid)) {
importanceMap.put(fid, importance + importanceMap.get(fid));
weightMap.put(fid, 1d + weightMap.get(fid));
} else {
featureScore.put(fid, 1);
importanceMap.put(fid, importance);
weightMap.put(fid, 1d);
}
}
}
return featureScore;
/* By default we calculate total gain and total cover.
Divide by the number of nodes per feature to get gain / cover */
if (importanceType == FeatureImportanceType.COVER
|| importanceType == FeatureImportanceType.GAIN) {
for (String fid: importanceMap.keySet()) {
importanceMap.put(fid, importanceMap.get(fid)/weightMap.get(fid));
}
}
return importanceMap;
}
/**

View File

@ -204,7 +204,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
/**
* Get importance of each feature
* Get importance of each feature based on weight only (number of splits)
*
* @return featureScoreMap key: feature index, value: feature importance score
*/
@ -214,7 +214,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
}
/**
* Get importance of each feature with specified feature names.
* Get importance of each feature based on weight only
* (number of splits), with specified feature names.
*
* @return featureScoreMap key: feature name, value: feature importance score
*/
@ -223,6 +224,31 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
booster.getFeatureScore(featureNames).asScala
}
/**
* Get importance of each feature based on information gain or cover
* Supported: ["gain, "cover", "total_gain", "total_cover"]
*
* @return featureScoreMap key: feature index, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getScore(featureMap: String, importanceType: String): Map[String, Double] = {
Map(booster.getScore(featureMap, importanceType)
.asScala.mapValues(_.doubleValue).toSeq: _*)
}
/**
* Get importance of each feature based on information gain or cover
* , with specified feature names.
* Supported: ["gain, "cover", "total_gain", "total_cover"]
*
* @return featureScoreMap key: feature name, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getScore(featureNames: Array[String], importanceType: String): Map[String, Double] = {
Map(booster.getScore(featureNames, importanceType)
.asScala.mapValues(_.doubleValue).toSeq: _*)
}
def getVersion: Int = booster.getVersion
def toByteArray: Array[Byte] = {

View File

@ -473,7 +473,7 @@ public class BoosterImplTest {
}
@Test
public void testGetFeatureImportance() throws XGBoostError {
public void testGetFeatureScore() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
@ -484,6 +484,54 @@ public class BoosterImplTest {
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
}
@Test
public void testGetFeatureImportanceGain() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
Booster booster = trainBooster(trainMat, testMat);
String[] featureNames = new String[126];
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
Map<String, Double> scoreMap = booster.getScore(featureNames, "gain");
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
}
@Test
public void testGetFeatureImportanceTotalGain() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
Booster booster = trainBooster(trainMat, testMat);
String[] featureNames = new String[126];
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
Map<String, Double> scoreMap = booster.getScore(featureNames, "total_gain");
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
}
@Test
public void testGetFeatureImportanceCover() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
Booster booster = trainBooster(trainMat, testMat);
String[] featureNames = new String[126];
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
Map<String, Double> scoreMap = booster.getScore(featureNames, "cover");
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
}
@Test
public void testGetFeatureImportanceTotalCover() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
Booster booster = trainBooster(trainMat, testMat);
String[] featureNames = new String[126];
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
Map<String, Double> scoreMap = booster.getScore(featureNames, "total_cover");
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
}
@Test
public void testFastHistoDepthwiseMaxDepth() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");