Set feature_names and feature_types in jvm-packages (#9364)
* 1. Add parameters to set feature names and feature types 2. Save feature names and feature types to native json model * Change serialization and deserialization format to ubj.
This commit is contained in:
parent
3632242e0b
commit
a1367ea1f8
@ -74,7 +74,9 @@ private[scala] case class XGBoostExecutionParams(
|
|||||||
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
earlyStoppingParams: XGBoostExecutionEarlyStoppingParams,
|
||||||
cacheTrainingSet: Boolean,
|
cacheTrainingSet: Boolean,
|
||||||
treeMethod: Option[String],
|
treeMethod: Option[String],
|
||||||
isLocal: Boolean) {
|
isLocal: Boolean,
|
||||||
|
featureNames: Option[Array[String]],
|
||||||
|
featureTypes: Option[Array[String]]) {
|
||||||
|
|
||||||
private var rawParamMap: Map[String, Any] = _
|
private var rawParamMap: Map[String, Any] = _
|
||||||
|
|
||||||
@ -213,6 +215,13 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
|
val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false)
|
||||||
.asInstanceOf[Boolean]
|
.asInstanceOf[Boolean]
|
||||||
|
|
||||||
|
val featureNames = if (overridedParams.contains("feature_names")) {
|
||||||
|
Some(overridedParams("feature_names").asInstanceOf[Array[String]])
|
||||||
|
} else None
|
||||||
|
val featureTypes = if (overridedParams.contains("feature_types")){
|
||||||
|
Some(overridedParams("feature_types").asInstanceOf[Array[String]])
|
||||||
|
} else None
|
||||||
|
|
||||||
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
|
val xgbExecParam = XGBoostExecutionParams(nWorkers, round, useExternalMemory, obj, eval,
|
||||||
missing, allowNonZeroForMissing, trackerConf,
|
missing, allowNonZeroForMissing, trackerConf,
|
||||||
checkpointParam,
|
checkpointParam,
|
||||||
@ -220,7 +229,10 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
|
|||||||
xgbExecEarlyStoppingParams,
|
xgbExecEarlyStoppingParams,
|
||||||
cacheTrainingSet,
|
cacheTrainingSet,
|
||||||
treeMethod,
|
treeMethod,
|
||||||
isLocal)
|
isLocal,
|
||||||
|
featureNames,
|
||||||
|
featureTypes
|
||||||
|
)
|
||||||
xgbExecParam.setRawParamMap(overridedParams)
|
xgbExecParam.setRawParamMap(overridedParams)
|
||||||
xgbExecParam
|
xgbExecParam
|
||||||
}
|
}
|
||||||
@ -531,6 +543,16 @@ private object Watches {
|
|||||||
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
||||||
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
||||||
|
|
||||||
|
if (xgbExecutionParams.featureNames.isDefined) {
|
||||||
|
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
|
||||||
|
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (xgbExecutionParams.featureTypes.isDefined) {
|
||||||
|
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
|
||||||
|
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
|
||||||
|
}
|
||||||
|
|
||||||
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -643,6 +665,15 @@ private object Watches {
|
|||||||
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
if (trainMargin.isDefined) trainMatrix.setBaseMargin(trainMargin.get)
|
||||||
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
if (testMargin.isDefined) testMatrix.setBaseMargin(testMargin.get)
|
||||||
|
|
||||||
|
if (xgbExecutionParams.featureNames.isDefined) {
|
||||||
|
trainMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
|
||||||
|
testMatrix.setFeatureNames(xgbExecutionParams.featureNames.get)
|
||||||
|
}
|
||||||
|
if (xgbExecutionParams.featureTypes.isDefined) {
|
||||||
|
trainMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
|
||||||
|
testMatrix.setFeatureTypes(xgbExecutionParams.featureTypes.get)
|
||||||
|
}
|
||||||
|
|
||||||
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
new Watches(Array(trainMatrix, testMatrix), Array("train", "test"), cacheDirName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -139,6 +139,12 @@ class XGBoostClassifier (
|
|||||||
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
||||||
set(singlePrecisionHistogram, value)
|
set(singlePrecisionHistogram, value)
|
||||||
|
|
||||||
|
def setFeatureNames(value: Array[String]): this.type =
|
||||||
|
set(featureNames, value)
|
||||||
|
|
||||||
|
def setFeatureTypes(value: Array[String]): this.type =
|
||||||
|
set(featureTypes, value)
|
||||||
|
|
||||||
// called at the start of fit/train when 'eval_metric' is not defined
|
// called at the start of fit/train when 'eval_metric' is not defined
|
||||||
private def setupDefaultEvalMetric(): String = {
|
private def setupDefaultEvalMetric(): String = {
|
||||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||||
|
|||||||
@ -141,6 +141,12 @@ class XGBoostRegressor (
|
|||||||
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
def setSinglePrecisionHistogram(value: Boolean): this.type =
|
||||||
set(singlePrecisionHistogram, value)
|
set(singlePrecisionHistogram, value)
|
||||||
|
|
||||||
|
def setFeatureNames(value: Array[String]): this.type =
|
||||||
|
set(featureNames, value)
|
||||||
|
|
||||||
|
def setFeatureTypes(value: Array[String]): this.type =
|
||||||
|
set(featureTypes, value)
|
||||||
|
|
||||||
// called at the start of fit/train when 'eval_metric' is not defined
|
// called at the start of fit/train when 'eval_metric' is not defined
|
||||||
private def setupDefaultEvalMetric(): String = {
|
private def setupDefaultEvalMetric(): String = {
|
||||||
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")
|
||||||
|
|||||||
@ -177,6 +177,21 @@ private[spark] trait GeneralParams extends Params {
|
|||||||
|
|
||||||
final def getSeed: Long = $(seed)
|
final def getSeed: Long = $(seed)
|
||||||
|
|
||||||
|
/** Feature's name, it will be set to DMatrix and Booster, and in the final native json model.
|
||||||
|
* In native code, the parameter name is feature_name.
|
||||||
|
* */
|
||||||
|
final val featureNames = new StringArrayParam(this, "feature_names",
|
||||||
|
"an array of feature names")
|
||||||
|
|
||||||
|
final def getFeatureNames: Array[String] = $(featureNames)
|
||||||
|
|
||||||
|
/** Feature types, q is numeric and c is categorical.
|
||||||
|
* In native code, the parameter name is feature_type
|
||||||
|
* */
|
||||||
|
final val featureTypes = new StringArrayParam(this, "feature_types",
|
||||||
|
"an array of feature types")
|
||||||
|
|
||||||
|
final def getFeatureTypes: Array[String] = $(featureTypes)
|
||||||
}
|
}
|
||||||
|
|
||||||
trait HasLeafPredictionCol extends Params {
|
trait HasLeafPredictionCol extends Params {
|
||||||
|
|||||||
@ -27,6 +27,8 @@ import org.apache.commons.io.IOUtils
|
|||||||
|
|
||||||
import org.apache.spark.Partitioner
|
import org.apache.spark.Partitioner
|
||||||
import org.apache.spark.ml.feature.VectorAssembler
|
import org.apache.spark.ml.feature.VectorAssembler
|
||||||
|
import org.json4s.{DefaultFormats, Formats}
|
||||||
|
import org.json4s.jackson.parseJson
|
||||||
|
|
||||||
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
|
class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {
|
||||||
|
|
||||||
@ -453,4 +455,26 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
|
|||||||
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
|
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
|
||||||
nativeUbjModelPath))
|
nativeUbjModelPath))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("native json model file should store feature_name and feature_type") {
|
||||||
|
val featureNames = (1 to 33).map(idx => s"feature_${idx}").toArray
|
||||||
|
val featureTypes = (1 to 33).map(idx => "q").toArray
|
||||||
|
val paramMap = Map("eta" -> "0.1", "max_depth" -> "6", "silent" -> "1",
|
||||||
|
"objective" -> "multi:softprob", "num_class" -> "6", "num_round" -> 5,
|
||||||
|
"num_workers" -> numWorkers, "tree_method" -> treeMethod
|
||||||
|
)
|
||||||
|
val trainingDF = buildDataFrame(MultiClassification.train)
|
||||||
|
val xgb = new XGBoostClassifier(paramMap)
|
||||||
|
.setFeatureNames(featureNames)
|
||||||
|
.setFeatureTypes(featureTypes)
|
||||||
|
val model = xgb.fit(trainingDF)
|
||||||
|
val modelStr = new String(model._booster.toByteArray("json"))
|
||||||
|
System.out.println(modelStr)
|
||||||
|
val jsonModel = parseJson(modelStr)
|
||||||
|
implicit val formats: Formats = DefaultFormats
|
||||||
|
val featureNamesInModel = (jsonModel \ "learner" \ "feature_names").extract[List[String]]
|
||||||
|
val featureTypesInModel = (jsonModel \ "learner" \ "feature_types").extract[List[String]]
|
||||||
|
assert(featureNamesInModel.length == 33)
|
||||||
|
assert(featureTypesInModel.length == 33)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -162,6 +162,51 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get feature names from the Booster.
|
||||||
|
* @return
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public final String[] getFeatureNames() throws XGBoostError {
|
||||||
|
int numFeature = (int) getNumFeature();
|
||||||
|
String[] out = new String[numFeature];
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_name", out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set feature names to the Booster.
|
||||||
|
*
|
||||||
|
* @param featureNames
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public void setFeatureNames(String[] featureNames) throws XGBoostError {
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
|
||||||
|
handle, "feature_name", featureNames));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get feature types from the Booster.
|
||||||
|
* @return
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public final String[] getFeatureTypes() throws XGBoostError {
|
||||||
|
int numFeature = (int) getNumFeature();
|
||||||
|
String[] out = new String[numFeature];
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetStrFeatureInfo(handle, "feature_type", out));
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set feature types to the Booster.
|
||||||
|
* @param featureTypes
|
||||||
|
* @throws XGBoostError
|
||||||
|
*/
|
||||||
|
public void setFeatureTypes(String[] featureTypes) throws XGBoostError {
|
||||||
|
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetStrFeatureInfo(
|
||||||
|
handle, "feature_type", featureTypes));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the booster for one iteration.
|
* Update the booster for one iteration.
|
||||||
*
|
*
|
||||||
@ -744,7 +789,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
|
||||||
try {
|
try {
|
||||||
out.writeInt(version);
|
out.writeInt(version);
|
||||||
out.writeObject(this.toByteArray());
|
out.writeObject(this.toByteArray("ubj"));
|
||||||
} catch (XGBoostError ex) {
|
} catch (XGBoostError ex) {
|
||||||
ex.printStackTrace();
|
ex.printStackTrace();
|
||||||
logger.error(ex.getMessage());
|
logger.error(ex.getMessage());
|
||||||
@ -780,7 +825,7 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
@Override
|
@Override
|
||||||
public void write(Kryo kryo, Output output) {
|
public void write(Kryo kryo, Output output) {
|
||||||
try {
|
try {
|
||||||
byte[] serObj = this.toByteArray();
|
byte[] serObj = this.toByteArray("ubj");
|
||||||
int serObjSize = serObj.length;
|
int serObjSize = serObj.length;
|
||||||
output.writeInt(serObjSize);
|
output.writeInt(serObjSize);
|
||||||
output.writeInt(version);
|
output.writeInt(version);
|
||||||
|
|||||||
@ -198,6 +198,8 @@ public class XGBoost {
|
|||||||
if (booster == null) {
|
if (booster == null) {
|
||||||
// Start training on a new booster
|
// Start training on a new booster
|
||||||
booster = new Booster(params, allMats);
|
booster = new Booster(params, allMats);
|
||||||
|
booster.setFeatureNames(dtrain.getFeatureNames());
|
||||||
|
booster.setFeatureTypes(dtrain.getFeatureTypes());
|
||||||
booster.loadRabitCheckpoint();
|
booster.loadRabitCheckpoint();
|
||||||
} else {
|
} else {
|
||||||
// Start training on an existing booster
|
// Start training on an existing booster
|
||||||
|
|||||||
@ -164,4 +164,8 @@ class XGBoostJNI {
|
|||||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||||
String featureJson, float missing, int nthread, long[] out);
|
String featureJson, float missing, int nthread, long[] out);
|
||||||
|
|
||||||
|
public final static native int XGBoosterSetStrFeatureInfo(long handle, String field, String[] features);
|
||||||
|
|
||||||
|
public final static native int XGBoosterGetStrFeatureInfo(long handle, String field, String[] out);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -205,6 +205,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
|||||||
jDMatrix.setBaseMargin(column)
|
jDMatrix.setBaseMargin(column)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set feature names
|
||||||
|
* @param values feature names
|
||||||
|
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def setFeatureNames(values: Array[String]): Unit = {
|
||||||
|
jDMatrix.setFeatureNames(values)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* set feature types
|
||||||
|
* @param values feature types
|
||||||
|
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def setFeatureTypes(values: Array[String]): Unit = {
|
||||||
|
jDMatrix.setFeatureTypes(values)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get group sizes of DMatrix (used for ranking)
|
* Get group sizes of DMatrix (used for ranking)
|
||||||
*/
|
*/
|
||||||
@ -243,6 +263,26 @@ class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
|||||||
jDMatrix.getBaseMargin
|
jDMatrix.getBaseMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get feature names
|
||||||
|
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def getFeatureNames: Array[String] = {
|
||||||
|
jDMatrix.getFeatureNames
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get feature types
|
||||||
|
* @throws ml.dmlc.xgboost4j.java.XGBoostError
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
@throws(classOf[XGBoostError])
|
||||||
|
def getFeatureTypes: Array[String] = {
|
||||||
|
jDMatrix.getFeatureTypes
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -1148,3 +1148,68 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFea
|
|||||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGBoosterSetStrFeatureInfo
|
||||||
|
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL
|
||||||
|
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo(
|
||||||
|
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
|
||||||
|
jobjectArray jfeatures) {
|
||||||
|
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||||
|
|
||||||
|
const char *field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
|
|
||||||
|
bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeatures);
|
||||||
|
|
||||||
|
std::vector<std::string> features;
|
||||||
|
std::vector<char const*> features_char;
|
||||||
|
|
||||||
|
for (bst_ulong i = 0; i < feature_num; ++i) {
|
||||||
|
jstring jfeature = (jstring)jenv->GetObjectArrayElement(jfeatures, i);
|
||||||
|
const char *s = jenv->GetStringUTFChars(jfeature, 0);
|
||||||
|
features.push_back(std::string(s, jenv->GetStringLength(jfeature)));
|
||||||
|
if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < features.size(); ++i) {
|
||||||
|
features_char.push_back(features[i].c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
int ret = XGBoosterSetStrFeatureInfo(
|
||||||
|
handle, field, dmlc::BeginPtr(features_char), feature_num);
|
||||||
|
JVM_CHECK_CALL(ret);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGBoosterSetGtrFeatureInfo
|
||||||
|
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL
|
||||||
|
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo(
|
||||||
|
JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield,
|
||||||
|
jobjectArray jout) {
|
||||||
|
BoosterHandle handle = (BoosterHandle)jhandle;
|
||||||
|
|
||||||
|
const char *field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
|
|
||||||
|
bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jout);
|
||||||
|
|
||||||
|
const char **features;
|
||||||
|
std::vector<char *> features_char;
|
||||||
|
|
||||||
|
int ret = XGBoosterGetStrFeatureInfo(handle, field, &feature_num,
|
||||||
|
(const char ***)&features);
|
||||||
|
JVM_CHECK_CALL(ret);
|
||||||
|
|
||||||
|
for (bst_ulong i = 0; i < feature_num; i++) {
|
||||||
|
jstring jfeature = jenv->NewStringUTF(features[i]);
|
||||||
|
jenv->SetObjectArrayElement(jout, i, jfeature);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|||||||
@ -383,6 +383,24 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixC
|
|||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
||||||
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
|
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGBoosterSetStrFeatureInfo
|
||||||
|
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL
|
||||||
|
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo
|
||||||
|
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
* Method: XGBoosterGetStrFeatureInfo
|
||||||
|
* Signature: (JLjava/lang/String;[Ljava/lang/String;])I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL
|
||||||
|
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo
|
||||||
|
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@ -16,10 +16,7 @@
|
|||||||
package ml.dmlc.xgboost4j.java;
|
package ml.dmlc.xgboost4j.java;
|
||||||
|
|
||||||
import java.io.*;
|
import java.io.*;
|
||||||
import java.util.Arrays;
|
import java.util.*;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.LinkedHashMap;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
import junit.framework.TestCase;
|
import junit.framework.TestCase;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
@ -122,6 +119,40 @@ public class BoosterImplTest {
|
|||||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void saveLoadModelWithFeaturesWithPath() throws XGBoostError, IOException {
|
||||||
|
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||||
|
DMatrix testMat = new DMatrix(this.test_uri);
|
||||||
|
IEvaluation eval = new EvalError();
|
||||||
|
|
||||||
|
String[] featureNames = new String[126];
|
||||||
|
String[] featureTypes = new String[126];
|
||||||
|
for(int i = 0; i < 126; i++) {
|
||||||
|
featureNames[i] = "test_feature_name_" + i;
|
||||||
|
featureTypes[i] = "q";
|
||||||
|
}
|
||||||
|
trainMat.setFeatureNames(featureNames);
|
||||||
|
testMat.setFeatureNames(featureNames);
|
||||||
|
trainMat.setFeatureTypes(featureTypes);
|
||||||
|
testMat.setFeatureTypes(featureTypes);
|
||||||
|
|
||||||
|
Booster booster = trainBooster(trainMat, testMat);
|
||||||
|
// save and load, only json format save and load feature_name and feature_type
|
||||||
|
File temp = File.createTempFile("temp", ".json");
|
||||||
|
temp.deleteOnExit();
|
||||||
|
booster.saveModel(temp.getAbsolutePath());
|
||||||
|
|
||||||
|
String modelString = new String(booster.toByteArray("json"));
|
||||||
|
System.out.println(modelString);
|
||||||
|
|
||||||
|
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
||||||
|
assert (Arrays.equals(bst2.toByteArray("ubj"), booster.toByteArray("ubj")));
|
||||||
|
assert (Arrays.equals(bst2.toByteArray("json"), booster.toByteArray("json")));
|
||||||
|
assert (Arrays.equals(bst2.toByteArray("deprecated"), booster.toByteArray("deprecated")));
|
||||||
|
float[][] predicts2 = bst2.predict(testMat, true, 0);
|
||||||
|
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void saveLoadModelWithStream() throws XGBoostError, IOException {
|
public void saveLoadModelWithStream() throws XGBoostError, IOException {
|
||||||
DMatrix trainMat = new DMatrix(this.train_uri);
|
DMatrix trainMat = new DMatrix(this.train_uri);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user