Set feature_names and feature_types in jvm-packages (#9364)
* 1. Add parameters to set feature names and feature types 2. Save feature names and feature types to native json model * Change serialization and deserialization format to ubj.
This commit is contained in:
@@ -16,10 +16,7 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
import org.junit.Test;
|
||||
@@ -122,6 +119,40 @@ public class BoosterImplTest {
|
||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithFeaturesWithPath() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
DMatrix testMat = new DMatrix(this.test_uri);
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
String[] featureNames = new String[126];
|
||||
String[] featureTypes = new String[126];
|
||||
for(int i = 0; i < 126; i++) {
|
||||
featureNames[i] = "test_feature_name_" + i;
|
||||
featureTypes[i] = "q";
|
||||
}
|
||||
trainMat.setFeatureNames(featureNames);
|
||||
testMat.setFeatureNames(featureNames);
|
||||
trainMat.setFeatureTypes(featureTypes);
|
||||
testMat.setFeatureTypes(featureTypes);
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
// save and load, only json format save and load feature_name and feature_type
|
||||
File temp = File.createTempFile("temp", ".json");
|
||||
temp.deleteOnExit();
|
||||
booster.saveModel(temp.getAbsolutePath());
|
||||
|
||||
String modelString = new String(booster.toByteArray("json"));
|
||||
System.out.println(modelString);
|
||||
|
||||
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
||||
assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
|
||||
assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json")));
|
||||
assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated")));
|
||||
float[][] predicts2 = bst2.predict(testMat, true, 0);
|
||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithStream() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||
|
||||
Reference in New Issue
Block a user