[jvm-packages] add jni for setting feature name and type (#7966)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
Copyright (c) 2014-2022 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -236,6 +236,66 @@ public class DMatrix {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
|
||||
}
|
||||
|
||||
private void setXGBDMatrixFeatureInfo(String type, String[] values) throws XGBoostError {
|
||||
if (type == null || type.isEmpty()) {
|
||||
throw new XGBoostError("Found empty type");
|
||||
}
|
||||
if (values == null || values.length == 0) {
|
||||
throw new XGBoostError("Found empty values");
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetStrFeatureInfo(handle, type, values));
|
||||
}
|
||||
|
||||
private String[] getXGBDMatrixFeatureInfo(String type) throws XGBoostError {
|
||||
if (type == null || type.isEmpty()) {
|
||||
throw new XGBoostError("Found empty type");
|
||||
}
|
||||
long[] outLen = new long[1];
|
||||
String[][] outValue = new String[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetStrFeatureInfo(handle, type, outLen, outValue));
|
||||
|
||||
if (outLen[0] != outValue[0].length) {
|
||||
throw new RuntimeException("Failed to get " + type);
|
||||
}
|
||||
return outValue[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature names
|
||||
* @param values feature names to be set
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureNames(String[] values) throws XGBoostError {
|
||||
setXGBDMatrixFeatureInfo("feature_name", values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature names
|
||||
* @return an array of feature names to be returned
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public String[] getFeatureNames() throws XGBoostError {
|
||||
return getXGBDMatrixFeatureInfo("feature_name");
|
||||
}
|
||||
|
||||
/**
|
||||
* Set feature types
|
||||
* @param values feature types to be set
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public void setFeatureTypes(String[] values) throws XGBoostError {
|
||||
setXGBDMatrixFeatureInfo("feature_type", values);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get feature types
|
||||
* @return an array of feature types to be returned
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public String[] getFeatureTypes() throws XGBoostError {
|
||||
return getXGBDMatrixFeatureInfo("feature_type");
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
*
|
||||
|
||||
@@ -82,6 +82,19 @@ class XGBoostJNI {
|
||||
|
||||
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
|
||||
|
||||
/**
|
||||
* Set the feature information
|
||||
* @param handle the DMatrix native address
|
||||
* @param field "feature_names" or "feature_types"
|
||||
* @param values an array of string
|
||||
* @return 0 when success, -1 when failure happens
|
||||
*/
|
||||
public final static native int XGDMatrixSetStrFeatureInfo(long handle, String field,
|
||||
String[] values);
|
||||
|
||||
public final static native int XGDMatrixGetStrFeatureInfo(long handle, String field,
|
||||
long[] outLength, String[][] outValues);
|
||||
|
||||
public final static native int XGDMatrixNumRow(long handle, long[] row);
|
||||
|
||||
public final static native int XGBoosterCreate(long[] handles, long[] out);
|
||||
@@ -143,4 +156,5 @@ class XGBoostJNI {
|
||||
|
||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
String featureJson, float missing, int nthread, long[] out);
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user