Set feature_names and feature_types in jvm-packages (#9364)

* 1. Add parameters to set feature names and feature types.
  2. Save feature names and feature types to the native JSON model.

* Change serialization and deserialization format to ubj.
This commit is contained in:
jinmfeng001
2023-07-12 15:18:46 +08:00
committed by GitHub
parent 3632242e0b
commit a1367ea1f8
12 changed files with 295 additions and 8 deletions

View File

@@ -162,6 +162,51 @@ public class Booster implements Serializable, KryoSerializable {
}
}
/**
 * Reads the booster's feature names through the native
 * {@code XGBoosterGetStrFeatureInfo} call using the {@code "feature_name"} field.
 *
 * @return an array whose length is {@code getNumFeature()} (narrowed to {@code int}),
 *         after it has been handed to the native call to be filled in
 * @throws XGBoostError propagated from {@code getNumFeature()} or from
 *         {@code XGBoostJNI.checkCall} on the native status code
 */
public final String[] getFeatureNames() throws XGBoostError {
  // The result buffer is sized up front; the native side receives it as an out-parameter.
  final String[] names = new String[(int) getNumFeature()];
  final int status = XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_name", names);
  XGBoostJNI.checkCall(status);
  return names;
}
/**
 * Stores feature names on the booster through the native
 * {@code XGBoosterSetStrFeatureInfo} call using the {@code "feature_name"} field.
 *
 * @param featureNames the names to store; passed to the native layer unchanged
 * @throws XGBoostError propagated from {@code XGBoostJNI.checkCall} on the native status code
 */
public void setFeatureNames(String[] featureNames) throws XGBoostError {
  final int status =
      XGBoostJNI.XGBoosterSetStrFeatureInfo(handle, "feature_name", featureNames);
  XGBoostJNI.checkCall(status);
}
/**
 * Reads the booster's feature types through the native
 * {@code XGBoosterGetStrFeatureInfo} call using the {@code "feature_type"} field.
 *
 * @return an array whose length is {@code getNumFeature()} (narrowed to {@code int}),
 *         after it has been handed to the native call to be filled in
 * @throws XGBoostError propagated from {@code getNumFeature()} or from
 *         {@code XGBoostJNI.checkCall} on the native status code
 */
public final String[] getFeatureTypes() throws XGBoostError {
  // The result buffer is sized up front; the native side receives it as an out-parameter.
  final String[] types = new String[(int) getNumFeature()];
  final int status = XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_type", types);
  XGBoostJNI.checkCall(status);
  return types;
}
/**
 * Stores feature types on the booster through the native
 * {@code XGBoosterSetStrFeatureInfo} call using the {@code "feature_type"} field.
 *
 * @param featureTypes the types to store; passed to the native layer unchanged
 * @throws XGBoostError propagated from {@code XGBoostJNI.checkCall} on the native status code
 */
public void setFeatureTypes(String[] featureTypes) throws XGBoostError {
  final int status =
      XGBoostJNI.XGBoosterSetStrFeatureInfo(handle, "feature_type", featureTypes);
  XGBoostJNI.checkCall(status);
}
/**
* Update the booster for one iteration.
*
@@ -744,7 +789,7 @@ public class Booster implements Serializable, KryoSerializable {
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeInt(version);
out.writeObject(this.toByteArray());
out.writeObject(this.toByteArray("ubj"));
} catch (XGBoostError ex) {
ex.printStackTrace();
logger.error(ex.getMessage());
@@ -780,7 +825,7 @@ public class Booster implements Serializable, KryoSerializable {
@Override
public void write(Kryo kryo, Output output) {
try {
byte[] serObj = this.toByteArray();
byte[] serObj = this.toByteArray("ubj");
int serObjSize = serObj.length;
output.writeInt(serObjSize);
output.writeInt(version);

View File

@@ -198,6 +198,8 @@ public class XGBoost {
if (booster == null) {
// Start training on a new booster
booster = new Booster(params, allMats);
booster.setFeatureNames(dtrain.getFeatureNames());
booster.setFeatureTypes(dtrain.getFeatureTypes());
booster.loadRabitCheckpoint();
} else {
// Start training on an existing booster

View File

@@ -164,4 +164,8 @@ class XGBoostJNI {
// Native entry point; its implementation is not visible in this file.
// NOTE(review): presumably builds a DMatrix from a column-wise array-interface JSON
// description and writes the new handle into out[0] -- confirm against the JNI source.
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);
// Native setter for string-valued feature info on a booster handle.
// Booster calls it with field "feature_name" or "feature_type" and passes the
// returned int status to checkCall.
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
// Native getter counterpart: Booster passes a pre-allocated String[] as `out`
// (sized by getNumFeature()) with field "feature_name" or "feature_type", and
// passes the returned int status to checkCall.
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
}

View File

@@ -205,6 +205,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.setBaseMargin(column)
}
/**
 * Assign feature names by delegating to the wrapped Java DMatrix.
 *
 * @param values feature names, forwarded unchanged
 * @throws ml.dmlc.xgboost4j.java.XGBoostError declared for failures of the delegated call
 */
@throws(classOf[XGBoostError])
def setFeatureNames(values: Array[String]): Unit =
  jDMatrix.setFeatureNames(values)
/**
 * Assign feature types by delegating to the wrapped Java DMatrix.
 *
 * @param values feature types, forwarded unchanged
 * @throws ml.dmlc.xgboost4j.java.XGBoostError declared for failures of the delegated call
 */
@throws(classOf[XGBoostError])
def setFeatureTypes(values: Array[String]): Unit =
  jDMatrix.setFeatureTypes(values)
/**
* Get group sizes of DMatrix (used for ranking)
*/
@@ -243,6 +263,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
jDMatrix.getBaseMargin
}
/**
 * Fetch feature names by delegating to the wrapped Java DMatrix.
 *
 * @throws ml.dmlc.xgboost4j.java.XGBoostError declared for failures of the delegated call
 * @return whatever the wrapped DMatrix's `getFeatureNames` returns
 */
@throws(classOf[XGBoostError])
def getFeatureNames: Array[String] =
  jDMatrix.getFeatureNames
/**
 * Fetch feature types by delegating to the wrapped Java DMatrix.
 *
 * @throws ml.dmlc.xgboost4j.java.XGBoostError declared for failures of the delegated call
 * @return whatever the wrapped DMatrix's `getFeatureTypes` returns
 */
@throws(classOf[XGBoostError])
def getFeatureTypes: Array[String] =
  jDMatrix.getFeatureTypes
/**
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
*