[jvm-packages] Add methods operating attributes of booster in jvm package, which follow API design in python package. (#4336)

This commit is contained in:
Xu Xiao
2019-04-09 02:00:35 +08:00
committed by Nan Zhu
parent 9080bba815
commit 60a9af567c
6 changed files with 197 additions and 0 deletions

View File

@@ -116,6 +116,60 @@ public class Booster implements Serializable, KryoSerializable {
}
}
/**
* Get attributes stored in the Booster as a Map.
*
* @return A map contain attribute pairs.
* @throws XGBoostError native error
*/
public final Map<String, String> getAttrs() throws XGBoostError {
String[][] attrNames = new String[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttrNames(handle, attrNames));
Map<String, String> attrMap = new HashMap<>();
for (String name: attrNames[0]) {
attrMap.put(name, this.getAttr(name));
}
return attrMap;
}
/**
* Get attribute from the Booster.
*
* @param key attribute key
* @return attribute value
* @throws XGBoostError native error
*/
public final String getAttr(String key) throws XGBoostError {
String[] attrValue = new String[1];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetAttr(handle, key, attrValue));
return attrValue[0];
}
/**
* Set attribute to the Booster.
*
* @param key attribute key
* @param value attribute value
* @throws XGBoostError native error
*/
public final void setAttr(String key, String value) throws XGBoostError {
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetAttr(handle, key, value));
}
/**
* Set attributes to the Booster.
*
* @param attrs attributes key-value map
* @throws XGBoostError native error
*/
public void setAttrs(Map<String, String> attrs) throws XGBoostError {
if (attrs != null) {
for (Map.Entry<String, String> entry : attrs.entrySet()) {
setAttr(entry.getKey(), entry.getValue());
}
}
}
/**
* Update the booster for one iteration.
*

View File

@@ -114,6 +114,7 @@ class XGBoostJNI {
public final static native int XGBoosterDumpModelExWithFeatures(
long handle, String[] feature_names, int with_stats, String format, String[][] out_strings);
public final static native int XGBoosterGetAttrNames(long handle, String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);

View File

@@ -32,6 +32,48 @@ import scala.collection.mutable
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
extends Serializable with KryoSerializable {
/**
* Get attributes stored in the Booster as a Map.
*
* @return A map contain attribute pairs.
*/
@throws(classOf[XGBoostError])
def getAttrs: Map[String, String] = {
booster.getAttrs.asScala.toMap
}
/**
* Get attribute from the Booster.
*
* @param key attr name
* @return attr value
*/
@throws(classOf[XGBoostError])
def getAttr(key: String): String = {
booster.getAttr(key)
}
/**
* Set attribute to the Booster.
*
* @param key attr name
* @param value attr value
*/
@throws(classOf[XGBoostError])
def setAttr(key: String, value: String): Unit = {
booster.setAttr(key, value)
}
/**
* set attributes
*
* @param params attributes key-value map
*/
@throws(classOf[XGBoostError])
def setAttrs(params: Map[String, String]): Unit = {
booster.setAttrs(params.asJava)
}
/**
* Set parameter to the Booster.
*