Set feature_names and feature_types in jvm-packages (#9364)

* 1. Add parameters to set feature names and feature types
2. Save feature names and feature types to native json model

* Change serialization and deserialization format to ubj.
This commit is contained in:
jinmfeng001
2023-07-12 15:18:46 +08:00
committed by GitHub
parent 3632242e0b
commit a1367ea1f8
12 changed files with 295 additions and 8 deletions

View File

@@ -16,10 +16,7 @@
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.*;
import junit.framework.TestCase;
import org.junit.Test;
@@ -122,6 +119,40 @@ public class BoosterImplTest {
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
}
/**
 * Trains a booster on matrices that carry explicit feature names and feature types, saves it
 * to a temporary {@code .json} file, reloads it from that path, and checks that the reloaded
 * booster serializes to the same bytes as the original in every format and still predicts
 * below the error threshold.
 *
 * @throws XGBoostError if any native XGBoost call fails
 * @throws IOException if the temporary model file cannot be created
 */
@Test
public void saveLoadModelWithFeaturesWithPath() throws XGBoostError, IOException {
  DMatrix trainMat = new DMatrix(this.train_uri);
  DMatrix testMat = new DMatrix(this.test_uri);
  IEvaluation eval = new EvalError();

  // NOTE(review): 126 is presumably the column count of the train/test data sets - confirm.
  final int featureCount = 126;
  String[] featureNames = new String[featureCount];
  String[] featureTypes = new String[featureCount];
  for (int i = 0; i < featureCount; i++) {
    featureNames[i] = "test_feature_name_" + i;
    featureTypes[i] = "q";
  }

  trainMat.setFeatureNames(featureNames);
  testMat.setFeatureNames(featureNames);
  trainMat.setFeatureTypes(featureTypes);
  testMat.setFeatureTypes(featureTypes);

  Booster booster = trainBooster(trainMat, testMat);

  // Save to a .json path: per the original author's note, only the json format saves and
  // loads feature_name and feature_type.
  File temp = File.createTempFile("temp", ".json");
  temp.deleteOnExit();
  booster.saveModel(temp.getAbsolutePath());

  Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());

  // Use the test framework's assertions rather than the `assert` keyword: `assert` is
  // disabled unless the JVM is started with -ea, which would silently skip these checks.
  TestCase.assertTrue(Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
  TestCase.assertTrue(Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json")));
  TestCase.assertTrue(
      Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated")));

  float[][] predicts2 = bst2.predict(testMat, true, 0);
  TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
}
@Test
public void saveLoadModelWithStream() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix(this.train_uri);