[jvm-packages] Updates to Java Booster to support other feature importance measures (#3801)
* Updates to Booster to support other feature importances * Add returns for Java methods * Pass Scala style checks * Pass Java style checks * Fix indents * Use class instead of enum * Return map string double * A no longer broken build, thanks to mvn package local build * Add a unit test to increase code coverage back * Address code review on main code * Add more unit tests for different feature importance scores * Address more CR
This commit is contained in:
parent
1f022929f4
commit
431c850c03
@ -16,9 +16,12 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import com.esotericsoftware.kryo.Kryo;
|
||||
import com.esotericsoftware.kryo.KryoSerializable;
|
||||
@ -395,6 +398,25 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Supported feature importance types
|
||||
*
|
||||
* WEIGHT = Number of nodes that a feature was used to determine a split
|
||||
* GAIN = Average information gain per split for a feature
|
||||
* COVER = Average cover per split for a feature
|
||||
* TOTAL_GAIN = Total information gain over all splits of a feature
|
||||
* TOTAL_COVER = Total cover over all splits of a feature
|
||||
*/
|
||||
public static class FeatureImportanceType {
|
||||
public static final String WEIGHT = "weight";
|
||||
public static final String GAIN = "gain";
|
||||
public static final String COVER = "cover";
|
||||
public static final String TOTAL_GAIN = "total_gain";
|
||||
public static final String TOTAL_COVER = "total_cover";
|
||||
public static final Set<String> ACCEPTED_TYPES = new HashSet<>(
|
||||
Arrays.asList(WEIGHT, GAIN, COVER, TOTAL_GAIN, TOTAL_COVER));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get importance of each feature with specified feature names.
|
||||
*
|
||||
@ -403,6 +425,28 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore(String[] featureNames) throws XGBoostError {
|
||||
String[] modelInfos = getModelDump(featureNames, false);
|
||||
return getFeatureWeightsFromModel(modelInfos);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get importance of each feature
|
||||
*
|
||||
* @return featureScoreMap key: feature index, value: feature importance score, can be nill
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
||||
String[] modelInfos = getModelDump(featureMap, false);
|
||||
return getFeatureWeightsFromModel(modelInfos);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the importance of each feature based purely on weights (number of splits)
|
||||
*
|
||||
* @return featureScoreMap key: feature index,
|
||||
* value: feature importance score based on weight
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
private Map<String, Integer> getFeatureWeightsFromModel(String[] modelInfos) throws XGBoostError {
|
||||
Map<String, Integer> featureScore = new HashMap<>();
|
||||
for (String tree : modelInfos) {
|
||||
for (String node : tree.split("\n")) {
|
||||
@ -423,30 +467,91 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
}
|
||||
|
||||
/**
|
||||
* Get importance of each feature
|
||||
* Get the feature importances for gain or cover (average or total)
|
||||
*
|
||||
* @return featureScoreMap key: feature index, value: feature importance score, can be null
|
||||
* @return featureImportanceMap key: feature index,
|
||||
* values: feature importance score based on gain or cover
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
||||
String[] modelInfos = getModelDump(featureMap, false);
|
||||
Map<String, Integer> featureScore = new HashMap<>();
|
||||
for (String tree : modelInfos) {
|
||||
for (String node : tree.split("\n")) {
|
||||
public Map<String, Double> getScore(
|
||||
String[] featureNames, String importanceType) throws XGBoostError {
|
||||
String[] modelInfos = getModelDump(featureNames, true);
|
||||
return getFeatureImportanceFromModel(modelInfos, importanceType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the feature importances for gain or cover (average or total), with feature names
|
||||
*
|
||||
* @return featureImportanceMap key: feature name,
|
||||
* values: feature importance score based on gain or cover
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public Map<String, Double> getScore(
|
||||
String featureMap, String importanceType) throws XGBoostError {
|
||||
String[] modelInfos = getModelDump(featureMap, true);
|
||||
return getFeatureImportanceFromModel(modelInfos, importanceType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the importance of each feature based on information gain or cover
|
||||
*
|
||||
* @return featureImportanceMap key: feature index, value: feature importance score
|
||||
* based on information gain or cover
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
private Map<String, Double> getFeatureImportanceFromModel(
|
||||
String[] modelInfos, String importanceType) throws XGBoostError {
|
||||
if (!FeatureImportanceType.ACCEPTED_TYPES.contains(importanceType)) {
|
||||
throw new AssertionError(String.format("Importance type %s is not supported",
|
||||
importanceType));
|
||||
}
|
||||
Map<String, Double> importanceMap = new HashMap<>();
|
||||
Map<String, Double> weightMap = new HashMap<>();
|
||||
if (importanceType == FeatureImportanceType.WEIGHT) {
|
||||
Map<String, Integer> importanceWeights = getFeatureWeightsFromModel(modelInfos);
|
||||
for (String feature: importanceWeights.keySet()) {
|
||||
importanceMap.put(feature, new Double(importanceWeights.get(feature)));
|
||||
}
|
||||
return importanceMap;
|
||||
}
|
||||
/* Each split in the tree has this text form:
|
||||
"0:[f28<-9.53674316e-07] yes=1,no=2,missing=1,gain=4000.53101,cover=1628.25"
|
||||
So the line has to be split according to whether cover or gain is desired */
|
||||
String splitter = "gain=";
|
||||
if (importanceType == FeatureImportanceType.COVER
|
||||
|| importanceType == FeatureImportanceType.TOTAL_COVER) {
|
||||
splitter = "cover=";
|
||||
}
|
||||
for (String tree: modelInfos) {
|
||||
for (String node: tree.split("\n")) {
|
||||
String[] array = node.split("\\[");
|
||||
if (array.length == 1) {
|
||||
continue;
|
||||
}
|
||||
String fid = array[1].split("\\]")[0];
|
||||
fid = fid.split("<")[0];
|
||||
if (featureScore.containsKey(fid)) {
|
||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||
String[] fidWithImportance = array[1].split("\\]");
|
||||
// Extract gain or cover from string after closing bracket
|
||||
Double importance = Double.parseDouble(
|
||||
fidWithImportance[1].split(splitter)[1].split(",")[0]
|
||||
);
|
||||
String fid = fidWithImportance[0].split("<")[0];
|
||||
if (importanceMap.containsKey(fid)) {
|
||||
importanceMap.put(fid, importance + importanceMap.get(fid));
|
||||
weightMap.put(fid, 1d + weightMap.get(fid));
|
||||
} else {
|
||||
featureScore.put(fid, 1);
|
||||
importanceMap.put(fid, importance);
|
||||
weightMap.put(fid, 1d);
|
||||
}
|
||||
}
|
||||
}
|
||||
return featureScore;
|
||||
/* By default we calculate total gain and total cover.
|
||||
Divide by the number of nodes per feature to get gain / cover */
|
||||
if (importanceType == FeatureImportanceType.COVER
|
||||
|| importanceType == FeatureImportanceType.GAIN) {
|
||||
for (String fid: importanceMap.keySet()) {
|
||||
importanceMap.put(fid, importanceMap.get(fid)/weightMap.get(fid));
|
||||
}
|
||||
}
|
||||
return importanceMap;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -204,7 +204,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
|
||||
|
||||
/**
|
||||
* Get importance of each feature
|
||||
* Get importance of each feature based on weight only (number of splits)
|
||||
*
|
||||
* @return featureScoreMap key: feature index, value: feature importance score
|
||||
*/
|
||||
@ -214,7 +214,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get importance of each feature with specified feature names.
|
||||
* Get importance of each feature based on weight only
|
||||
* (number of splits), with specified feature names.
|
||||
*
|
||||
* @return featureScoreMap key: feature name, value: feature importance score
|
||||
*/
|
||||
@ -223,6 +224,31 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
booster.getFeatureScore(featureNames).asScala
|
||||
}
|
||||
|
||||
/**
  * Get importance of each feature for the requested importance type.
  * Supported: ["weight", "gain", "cover", "total_gain", "total_cover"]
  *
  * @return featureScoreMap  key: feature index, value: feature importance score
  */
@throws(classOf[XGBoostError])
def getScore(featureMap: String, importanceType: String): Map[String, Double] = {
  // Unbox the java.lang.Double values returned by the Java booster
  val scored = booster.getScore(featureMap, importanceType).asScala.toSeq.map {
    case (feature, score) => feature -> score.doubleValue
  }
  Map(scored: _*)
}
|
||||
|
||||
/**
  * Get importance of each feature for the requested importance type,
  * with specified feature names.
  * Supported: ["weight", "gain", "cover", "total_gain", "total_cover"]
  *
  * @return featureScoreMap  key: feature name, value: feature importance score
  */
@throws(classOf[XGBoostError])
def getScore(featureNames: Array[String], importanceType: String): Map[String, Double] = {
  // Unbox the java.lang.Double values returned by the Java booster
  val scored = booster.getScore(featureNames, importanceType).asScala.toSeq.map {
    case (feature, score) => feature -> score.doubleValue
  }
  Map(scored: _*)
}
|
||||
|
||||
def getVersion: Int = booster.getVersion
|
||||
|
||||
def toByteArray: Array[Byte] = {
|
||||
|
||||
@ -473,7 +473,7 @@ public class BoosterImplTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportance() throws XGBoostError {
|
||||
public void testGetFeatureScore() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
@ -484,6 +484,54 @@ public class BoosterImplTest {
|
||||
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportanceGain() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||
Map<String, Double> scoreMap = booster.getScore(featureNames, "gain");
|
||||
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportanceTotalGain() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||
Map<String, Double> scoreMap = booster.getScore(featureNames, "total_gain");
|
||||
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportanceCover() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||
Map<String, Double> scoreMap = booster.getScore(featureNames, "cover");
|
||||
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportanceTotalCover() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||
Map<String, Double> scoreMap = booster.getScore(featureNames, "total_cover");
|
||||
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFastHistoDepthwiseMaxDepth() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user