[jvm-packages] Updates to Java Booster to support other feature importance measures (#3801)
* Updates to Booster to support other feature importances * Add returns for Java methods * Pass Scala style checks * Pass Java style checks * Fix indents * Use class instead of enum * Return map string double * A no longer broken build, thanks to mvn package local build * Add a unit test to increase code coverage back * Address code review on main code * Add more unit tests for different feature importance scores * Address more CR
This commit is contained in:
parent
1f022929f4
commit
431c850c03
@ -16,9 +16,12 @@
|
|||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
|
import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
import com.esotericsoftware.kryo.Kryo;
|
import com.esotericsoftware.kryo.Kryo;
|
||||||
import com.esotericsoftware.kryo.KryoSerializable;
|
import com.esotericsoftware.kryo.KryoSerializable;
|
||||||
@ -395,6 +398,25 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
return modelInfos[0];
|
return modelInfos[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Supported feature importance types
|
||||||
|
*
|
||||||
|
* WEIGHT = Number of nodes that a feature was used to determine a split
|
||||||
|
* GAIN = Average information gain per split for a feature
|
||||||
|
* COVER = Average cover per split for a feature
|
||||||
|
* TOTAL_GAIN = Total information gain over all splits of a feature
|
||||||
|
* TOTAL_COVER = Total cover over all splits of a feature
|
||||||
|
*/
|
||||||
|
public static class FeatureImportanceType {
|
||||||
|
public static final String WEIGHT = "weight";
|
||||||
|
public static final String GAIN = "gain";
|
||||||
|
public static final String COVER = "cover";
|
||||||
|
public static final String TOTAL_GAIN = "total_gain";
|
||||||
|
public static final String TOTAL_COVER = "total_cover";
|
||||||
|
public static final Set<String> ACCEPTED_TYPES = new HashSet<>(
|
||||||
|
Arrays.asList(WEIGHT, GAIN, COVER, TOTAL_GAIN, TOTAL_COVER));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get importance of each feature with specified feature names.
|
* Get importance of each feature with specified feature names.
|
||||||
*
|
*
|
||||||
@ -403,6 +425,28 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
*/
|
*/
|
||||||
public Map<String, Integer> getFeatureScore(String[] featureNames) throws XGBoostError {
|
public Map<String, Integer> getFeatureScore(String[] featureNames) throws XGBoostError {
|
||||||
String[] modelInfos = getModelDump(featureNames, false);
|
String[] modelInfos = getModelDump(featureNames, false);
|
||||||
|
return getFeatureWeightsFromModel(modelInfos);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get importance of each feature
|
||||||
|
*
|
||||||
|
* @return featureScoreMap key: feature index, value: feature importance score, can be nill
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
||||||
|
String[] modelInfos = getModelDump(featureMap, false);
|
||||||
|
return getFeatureWeightsFromModel(modelInfos);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the importance of each feature based purely on weights (number of splits)
|
||||||
|
*
|
||||||
|
* @return featureScoreMap key: feature index,
|
||||||
|
* value: feature importance score based on weight
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
private Map<String, Integer> getFeatureWeightsFromModel(String[] modelInfos) throws XGBoostError {
|
||||||
Map<String, Integer> featureScore = new HashMap<>();
|
Map<String, Integer> featureScore = new HashMap<>();
|
||||||
for (String tree : modelInfos) {
|
for (String tree : modelInfos) {
|
||||||
for (String node : tree.split("\n")) {
|
for (String node : tree.split("\n")) {
|
||||||
@ -423,30 +467,91 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get importance of each feature
|
* Get the feature importances for gain or cover (average or total)
|
||||||
*
|
*
|
||||||
* @return featureScoreMap key: feature index, value: feature importance score, can be nill
|
* @return featureImportanceMap key: feature index,
|
||||||
|
* values: feature importance score based on gain or cover
|
||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
public Map<String, Double> getScore(
|
||||||
String[] modelInfos = getModelDump(featureMap, false);
|
String[] featureNames, String importanceType) throws XGBoostError {
|
||||||
Map<String, Integer> featureScore = new HashMap<>();
|
String[] modelInfos = getModelDump(featureNames, true);
|
||||||
for (String tree : modelInfos) {
|
return getFeatureImportanceFromModel(modelInfos, importanceType);
|
||||||
for (String node : tree.split("\n")) {
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the feature importances for gain or cover (average or total), with feature names
|
||||||
|
*
|
||||||
|
* @return featureImportanceMap key: feature name,
|
||||||
|
* values: feature importance score based on gain or cover
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
public Map<String, Double> getScore(
|
||||||
|
String featureMap, String importanceType) throws XGBoostError {
|
||||||
|
String[] modelInfos = getModelDump(featureMap, true);
|
||||||
|
return getFeatureImportanceFromModel(modelInfos, importanceType);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the importance of each feature based on information gain or cover
|
||||||
|
*
|
||||||
|
* @return featureImportanceMap key: feature index, value: feature importance score
|
||||||
|
* based on information gain or cover
|
||||||
|
* @throws XGBoostError native error
|
||||||
|
*/
|
||||||
|
private Map<String, Double> getFeatureImportanceFromModel(
|
||||||
|
String[] modelInfos, String importanceType) throws XGBoostError {
|
||||||
|
if (!FeatureImportanceType.ACCEPTED_TYPES.contains(importanceType)) {
|
||||||
|
throw new AssertionError(String.format("Importance type %s is not supported",
|
||||||
|
importanceType));
|
||||||
|
}
|
||||||
|
Map<String, Double> importanceMap = new HashMap<>();
|
||||||
|
Map<String, Double> weightMap = new HashMap<>();
|
||||||
|
if (importanceType == FeatureImportanceType.WEIGHT) {
|
||||||
|
Map<String, Integer> importanceWeights = getFeatureWeightsFromModel(modelInfos);
|
||||||
|
for (String feature: importanceWeights.keySet()) {
|
||||||
|
importanceMap.put(feature, new Double(importanceWeights.get(feature)));
|
||||||
|
}
|
||||||
|
return importanceMap;
|
||||||
|
}
|
||||||
|
/* Each split in the tree has this text form:
|
||||||
|
"0:[f28<-9.53674316e-07] yes=1,no=2,missing=1,gain=4000.53101,cover=1628.25"
|
||||||
|
So the line has to be split according to whether cover or gain is desired */
|
||||||
|
String splitter = "gain=";
|
||||||
|
if (importanceType == FeatureImportanceType.COVER
|
||||||
|
|| importanceType == FeatureImportanceType.TOTAL_COVER) {
|
||||||
|
splitter = "cover=";
|
||||||
|
}
|
||||||
|
for (String tree: modelInfos) {
|
||||||
|
for (String node: tree.split("\n")) {
|
||||||
String[] array = node.split("\\[");
|
String[] array = node.split("\\[");
|
||||||
if (array.length == 1) {
|
if (array.length == 1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
String fid = array[1].split("\\]")[0];
|
String[] fidWithImportance = array[1].split("\\]");
|
||||||
fid = fid.split("<")[0];
|
// Extract gain or cover from string after closing bracket
|
||||||
if (featureScore.containsKey(fid)) {
|
Double importance = Double.parseDouble(
|
||||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
fidWithImportance[1].split(splitter)[1].split(",")[0]
|
||||||
|
);
|
||||||
|
String fid = fidWithImportance[0].split("<")[0];
|
||||||
|
if (importanceMap.containsKey(fid)) {
|
||||||
|
importanceMap.put(fid, importance + importanceMap.get(fid));
|
||||||
|
weightMap.put(fid, 1d + weightMap.get(fid));
|
||||||
} else {
|
} else {
|
||||||
featureScore.put(fid, 1);
|
importanceMap.put(fid, importance);
|
||||||
|
weightMap.put(fid, 1d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return featureScore;
|
/* By default we calculate total gain and total cover.
|
||||||
|
Divide by the number of nodes per feature to get gain / cover */
|
||||||
|
if (importanceType == FeatureImportanceType.COVER
|
||||||
|
|| importanceType == FeatureImportanceType.GAIN) {
|
||||||
|
for (String fid: importanceMap.keySet()) {
|
||||||
|
importanceMap.put(fid, importanceMap.get(fid)/weightMap.get(fid));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return importanceMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -204,7 +204,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get importance of each feature
|
* Get importance of each feature based on weight only (number of splits)
|
||||||
*
|
*
|
||||||
* @return featureScoreMap key: feature index, value: feature importance score
|
* @return featureScoreMap key: feature index, value: feature importance score
|
||||||
*/
|
*/
|
||||||
@ -214,7 +214,8 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get importance of each feature with specified feature names.
|
* Get importance of each feature based on weight only
|
||||||
|
* (number of splits), with specified feature names.
|
||||||
*
|
*
|
||||||
* @return featureScoreMap key: feature name, value: feature importance score
|
* @return featureScoreMap key: feature name, value: feature importance score
|
||||||
*/
|
*/
|
||||||
@ -223,6 +224,31 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
|||||||
booster.getFeatureScore(featureNames).asScala
|
booster.getFeatureScore(featureNames).asScala
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Get importance of each feature based on information gain or cover
 * Supported: ["gain", "cover", "total_gain", "total_cover"]
 *
 * @return featureScoreMap key: feature index, value: feature importance score
 */
@throws(classOf[XGBoostError])
def getScore(featureMap: String, importanceType: String): Map[String, Double] = {
  val javaScores = booster.getScore(featureMap, importanceType).asScala
  // Unbox each java.lang.Double so the result holds Scala Doubles.
  val pairs = javaScores.toSeq.map { case (feature, score) => feature -> score.doubleValue }
  Map(pairs: _*)
}
|
||||||
|
|
||||||
|
/**
 * Get importance of each feature based on information gain or cover,
 * with specified feature names.
 * Supported: ["gain", "cover", "total_gain", "total_cover"]
 *
 * @return featureScoreMap key: feature name, value: feature importance score
 */
@throws(classOf[XGBoostError])
def getScore(featureNames: Array[String], importanceType: String): Map[String, Double] = {
  val javaScores = booster.getScore(featureNames, importanceType).asScala
  // Unbox each java.lang.Double so the result holds Scala Doubles.
  val pairs = javaScores.toSeq.map { case (feature, score) => feature -> score.doubleValue }
  Map(pairs: _*)
}
|
||||||
|
|
||||||
def getVersion: Int = booster.getVersion
|
def getVersion: Int = booster.getVersion
|
||||||
|
|
||||||
def toByteArray: Array[Byte] = {
|
def toByteArray: Array[Byte] = {
|
||||||
|
|||||||
@ -473,7 +473,7 @@ public class BoosterImplTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testGetFeatureImportance() throws XGBoostError {
|
public void testGetFeatureScore() throws XGBoostError {
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
@ -484,6 +484,54 @@ public class BoosterImplTest {
|
|||||||
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetFeatureImportanceGain() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
Booster booster = trainBooster(trainMat, testMat);
|
||||||
|
String[] featureNames = new String[126];
|
||||||
|
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||||
|
Map<String, Double> scoreMap = booster.getScore(featureNames, "gain");
|
||||||
|
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetFeatureImportanceTotalGain() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
Booster booster = trainBooster(trainMat, testMat);
|
||||||
|
String[] featureNames = new String[126];
|
||||||
|
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||||
|
Map<String, Double> scoreMap = booster.getScore(featureNames, "total_gain");
|
||||||
|
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetFeatureImportanceCover() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
Booster booster = trainBooster(trainMat, testMat);
|
||||||
|
String[] featureNames = new String[126];
|
||||||
|
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||||
|
Map<String, Double> scoreMap = booster.getScore(featureNames, "cover");
|
||||||
|
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testGetFeatureImportanceTotalCover() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
Booster booster = trainBooster(trainMat, testMat);
|
||||||
|
String[] featureNames = new String[126];
|
||||||
|
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||||
|
Map<String, Double> scoreMap = booster.getScore(featureNames, "total_cover");
|
||||||
|
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFastHistoDepthwiseMaxDepth() throws XGBoostError {
|
public void testFastHistoDepthwiseMaxDepth() throws XGBoostError {
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user