Set feature_names and feature_types in jvm-packages (#9364)
* 1. Add parameters to set feature names and feature types 2. Save feature names and feature types to native json model * Change serialization and deserialization format to ubj.
This commit is contained in:
@@ -162,6 +162,51 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature names from the Booster.
|
||||
* @return
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public final String[] getFeatureNames() throws XGBoostError {
|
||||
int numFeature = (int) getNumFeature();
|
||||
String[] out = new String[numFeature];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_name", out));
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature names to the Booster.
|
||||
*
|
||||
* @param featureNames
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureNames(String[] featureNames) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
|
||||
handle, "feature_name", featureNames));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature types from the Booster.
|
||||
* @return
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public final String[] getFeatureTypes() throws XGBoostError {
|
||||
int numFeature = (int) getNumFeature();
|
||||
String[] out = new String[numFeature];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_type", out));
|
||||
return out;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature types to the Booster.
|
||||
* @param featureTypes
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureTypes(String[] featureTypes) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
|
||||
handle, "feature_type", featureTypes));
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the booster for one iteration.
|
||||
*
|
||||
@@ -744,7 +789,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||
try {
|
||||
out.writeInt(version);
|
||||
out.writeObject(this.toByteArray());
|
||||
out.writeObject(this.toByteArray("ubj"));
|
||||
} catch (XGBoostError ex) {
|
||||
ex.printStackTrace();
|
||||
logger.error(ex.getMessage());
|
||||
@@ -780,7 +825,7 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
@Override
|
||||
public void write(Kryo kryo, Output output) {
|
||||
try {
|
||||
byte[] serObj = this.toByteArray();
|
||||
byte[] serObj = this.toByteArray("ubj");
|
||||
int serObjSize = serObj.length;
|
||||
output.writeInt(serObjSize);
|
||||
output.writeInt(version);
|
||||
|
||||
@@ -198,6 +198,8 @@ public class XGBoost {
|
||||
if (booster == null) {
|
||||
// Start training on a new booster
|
||||
booster = new Booster(params, allMats);
|
||||
booster.setFeatureNames(dtrain.getFeatureNames());
|
||||
booster.setFeatureTypes(dtrain.getFeatureTypes());
|
||||
booster.loadRabitCheckpoint();
|
||||
} else {
|
||||
// Start training on an existing booster
|
||||
|
||||
@@ -164,4 +164,8 @@ class XGBoostJNI {
|
||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
String featureJson, float missing, int nthread, long[] out);
|
||||
|
||||
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
|
||||
|
||||
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
|
||||
|
||||
}
|
||||
|
||||
@@ -205,6 +205,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
jDMatrix.setBaseMargin(column)
|
||||
}
|
||||
|
||||
/**
|
||||
* set feature names
|
||||
* @param values feature names
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setFeatureNames(values: Array[String]): Unit = {
|
||||
jDMatrix.setFeatureNames(values)
|
||||
}
|
||||
|
||||
/**
|
||||
* set feature types
|
||||
* @param values feature types
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setFeatureTypes(values: Array[String]): Unit = {
|
||||
jDMatrix.setFeatureTypes(values)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get group sizes of DMatrix (used for ranking)
|
||||
*/
|
||||
@@ -243,6 +263,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
jDMatrix.getBaseMargin
|
||||
}
|
||||
|
||||
/**
|
||||
* get feature names
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
* @return
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureNames: Array[String] = {
|
||||
jDMatrix.getFeatureNames
|
||||
}
|
||||
|
||||
/**
|
||||
* get feature types
|
||||
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||
* @return
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureTypes: Array[String] = {
|
||||
jDMatrix.getFeatureTypes
|
||||
}
|
||||
|
||||
/**
|
||||
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
||||
*
|
||||
|
||||
@@ -1148,3 +1148,68 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFea
|
||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo(
|
||||
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
|
||||
jobjectArray jfeatures) {
|
||||
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||
|
||||
const char *field = jenv->GetStringUTFChars(jfield, 0);
|
||||
|
||||
bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeatures);
|
||||
|
||||
std::vector<std::string> features;
|
||||
std::vector<char const*> features_char;
|
||||
|
||||
for (bst_ulong i = 0; i < feature_num; ++i) {
|
||||
jstring jfeature = (jstring)jenv->GetObjectArrayElement(jfeatures, i);
|
||||
const char *s = jenv->GetStringUTFChars(jfeature, 0);
|
||||
features.push_back(std::string(s, jenv->GetStringLength(jfeature)));
|
||||
if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature, s);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < features.size(); ++i) {
|
||||
features_char.push_back(features[i].c_str());
|
||||
}
|
||||
|
||||
int ret = XGBoosterSetStrFeatureInfo(
|
||||
handle, field, dmlc::BeginPtr(features_char), feature_num);
|
||||
JVM_CHECK_CALL(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetGtrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
|
||||
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
|
||||
jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||
|
||||
const char *field = jenv->GetStringUTFChars(jfield, 0);
|
||||
|
||||
bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jout);
|
||||
|
||||
const char **features;
|
||||
std::vector<char *> features_char;
|
||||
|
||||
int ret = XGBoosterGetStrFeatureInfo(handle, field, &feature_num,
|
||||
(const char ***)&features);
|
||||
JVM_CHECK_CALL(ret);
|
||||
|
||||
for (bst_ulong i = 0; i < feature_num; i++) {
|
||||
jstring jfeature = jenv->NewStringUTF(features[i]);
|
||||
jenv->SetObjectArrayElement(jout, i, jfeature);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -383,6 +383,24 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixC
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
||||
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetStrFeatureInfo
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
|
||||
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -16,10 +16,7 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Test;
|
||||
@@ -122,6 +119,40 @@ public class BoosterImplTest {
|
||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithFeaturesWithPath() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
String[] featureNames = new String[126];
|
||||
String[] featureTypes = new String[126];
|
||||
for(int i = 0; i < 126; i++) {
|
||||
featureNames[i] = "test_feature_name_" + i;
|
||||
featureTypes[i] = "q";
|
||||
}
|
||||
trainMat.setFeatureNames(featureNames);
|
||||
testMat.setFeatureNames(featureNames);
|
||||
trainMat.setFeatureTypes(featureTypes);
|
||||
testMat.setFeatureTypes(featureTypes);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
// save and load, only json format save and load feature_name and feature_type
|
||||
File temp = File.createTempFile("temp", ".json");
|
||||
temp.deleteOnExit();
|
||||
booster.saveModel(temp.getAbsolutePath());
|
||||
|
||||
String modelString = new String(booster.toByteArray("json"));
|
||||
System.out.println(modelString);
|
||||
|
||||
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
||||
assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
|
||||
assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json")));
|
||||
assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated")));
|
||||
float[][] predicts2 = bst2.predict(testMat, true, 0);
|
||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithStream() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
|
||||
Reference in New Issue
Block a user