[jvm-packages] Add methods operating attributes of booster in jvm package, which follow API design in python package. (#4336)
This commit is contained in:
parent
9080bba815
commit
60a9af567c
@ -116,6 +116,60 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get attributes stored in the Booster as a Map.
|
||||
*
|
||||
* @return A map contain attribute pairs.
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public final Map<String, String> getAttrs() throws XGBoostError {
|
||||
String[][] attrNames = new String[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttrNames(handle, attrNames));
|
||||
Map<String, String> attrMap = new HashMap<>();
|
||||
for (String name: attrNames[0]) {
|
||||
attrMap.put(name, this.getAttr(name));
|
||||
}
|
||||
return attrMap;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get attribute from the Booster.
|
||||
*
|
||||
* @param key attribute key
|
||||
* @return attribute value
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public final String getAttr(String key) throws XGBoostError {
|
||||
String[] attrValue = new String[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttr(handle, key, attrValue));
|
||||
return attrValue[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Set attribute to the Booster.
|
||||
*
|
||||
* @param key attribute key
|
||||
* @param value attribute value
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public final void setAttr(String key, String value) throws XGBoostError {
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetAttr(handle, key, value));
|
||||
}
|
||||
|
||||
/**
|
||||
* Set attributes to the Booster.
|
||||
*
|
||||
* @param attrs attributes key-value map
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setAttrs(Map<String, String> attrs) throws XGBoostError {
|
||||
if (attrs != null) {
|
||||
for (Map.Entry<String, String> entry : attrs.entrySet()) {
|
||||
setAttr(entry.getKey(), entry.getValue());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the booster for one iteration.
|
||||
*
|
||||
|
||||
@ -114,6 +114,7 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterDumpModelExWithFeatures(
|
||||
long handle, String[] feature_names, int with_stats, String format, String[][] out_strings);
|
||||
|
||||
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
|
||||
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||
|
||||
@ -32,6 +32,48 @@ import scala.collection.mutable
|
||||
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
extends Serializable with KryoSerializable {
|
||||
|
||||
/**
|
||||
* Get attributes stored in the Booster as a Map.
|
||||
*
|
||||
* @return A map contain attribute pairs.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getAttrs: Map[String, String] = {
|
||||
booster.getAttrs.asScala.toMap
|
||||
}
|
||||
|
||||
/**
|
||||
* Get attribute from the Booster.
|
||||
*
|
||||
* @param key attr name
|
||||
* @return attr value
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getAttr(key: String): String = {
|
||||
booster.getAttr(key)
|
||||
}
|
||||
|
||||
/**
|
||||
* Set attribute to the Booster.
|
||||
*
|
||||
* @param key attr name
|
||||
* @param value attr value
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setAttr(key: String, value: String): Unit = {
|
||||
booster.setAttr(key, value)
|
||||
}
|
||||
|
||||
/**
|
||||
* set attributes
|
||||
*
|
||||
* @param params attributes key-value map
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def setAttrs(params: Map[String, String]): Unit = {
|
||||
booster.setAttrs(params.asJava)
|
||||
}
|
||||
|
||||
/**
|
||||
* Set parameter to the Booster.
|
||||
*
|
||||
|
||||
@ -706,6 +706,68 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetAttrNames
|
||||
* Signature: (I[[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNames
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
bst_ulong len = 0;
|
||||
char **result;
|
||||
int ret = XGBoosterGetAttrNames(handle, &len, (const char ***) &result);
|
||||
|
||||
jsize jlen = (jsize) len;
|
||||
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
||||
for(int i=0 ; i<jlen; i++) {
|
||||
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i]));
|
||||
}
|
||||
jenv->SetObjectArrayElement(jout, 0, jinfos);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetAttr
|
||||
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jobjectArray jout) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
const char* key = jenv->GetStringUTFChars(jkey, 0);
|
||||
const char* result;
|
||||
int success;
|
||||
int ret = XGBoosterGetAttr(handle, key, &result, &success);
|
||||
//release
|
||||
if (key) jenv->ReleaseStringUTFChars(jkey, key);
|
||||
|
||||
if (success > 0) {
|
||||
jstring jret = jenv->NewStringUTF(result);
|
||||
jenv->SetObjectArrayElement(jout, 0, jret);
|
||||
}
|
||||
|
||||
return ret;
|
||||
};
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterSetAttr
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jstring jvalue) {
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
const char* key = jenv->GetStringUTFChars(jkey, 0);
|
||||
const char* value = jenv->GetStringUTFChars(jvalue, 0);
|
||||
int ret = XGBoosterSetAttr(handle, key, value);
|
||||
//release
|
||||
if (key) jenv->ReleaseStringUTFChars(jkey, key);
|
||||
if (value) jenv->ReleaseStringUTFChars(jvalue, value);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
|
||||
@ -231,6 +231,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelExWithFeatures
|
||||
(JNIEnv *, jclass, jlong, jobjectArray, jint, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetAttrNames
|
||||
* Signature: (I[[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNames
|
||||
(JNIEnv *, jclass, jlong, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetAttr
|
||||
|
||||
@ -617,4 +617,34 @@ public class BoosterImplTest {
|
||||
TestCase.assertTrue(booster1error == booster2error);
|
||||
TestCase.assertTrue(tempBoosterError > booster2error);
|
||||
}
|
||||
|
||||
/**
|
||||
* test set/get attributes to/from a booster
|
||||
*
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@Test
|
||||
public void testSetAndGetAttrs() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
booster.setAttr("testKey1", "testValue1");
|
||||
TestCase.assertEquals(booster.getAttr("testKey1"), "testValue1");
|
||||
booster.setAttr("testKey1", "testValue2");
|
||||
TestCase.assertEquals(booster.getAttr("testKey1"), "testValue2");
|
||||
|
||||
booster.setAttrs(new HashMap<String, String>(){{
|
||||
put("aa", "AA");
|
||||
put("bb", "BB");
|
||||
put("cc", "CC");
|
||||
}});
|
||||
|
||||
Map<String, String> attr = booster.getAttrs();
|
||||
TestCase.assertEquals(attr.size(), 4);
|
||||
TestCase.assertEquals(attr.get("testKey1"), "testValue2");
|
||||
TestCase.assertEquals(attr.get("aa"), "AA");
|
||||
TestCase.assertEquals(attr.get("bb"), "BB");
|
||||
TestCase.assertEquals(attr.get("cc"), "CC");
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user