add a DistTrainWithSpark example and trigger the training job with foreachPartition

This commit is contained in:
CodingCat 2016-03-06 10:16:11 -05:00
parent f768edfede
commit 808e30f9fc
13 changed files with 588 additions and 867 deletions

View File: pom.xml

@@ -25,7 +25,7 @@
<dependencies>
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<artifactId>xgboost4j-spark</artifactId>
<version>0.1</version>
</dependency>
<dependency>

View File: ml/dmlc/xgboost4j/scala/spark/demo/DistTrainWithSpark.scala

@@ -0,0 +1,74 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala.spark.demo
import java.io.File
import scala.collection.mutable.ListBuffer
import scala.io.Source
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.DenseVector
import org.apache.spark.mllib.regression.LabeledPoint
import ml.dmlc.xgboost4j.scala.DMatrix
import ml.dmlc.xgboost4j.scala.spark.XGBoost
object DistTrainWithSpark {
// read an SVMLight-format file fully into memory as Spark LabeledPoints
private def readFile(filePath: String): List[LabeledPoint] = {
val file = Source.fromFile(new File(filePath))
val sampleList = new ListBuffer[LabeledPoint]
for (sample <- file.getLines()) {
sampleList += fromSVMStringToLabeledPoint(sample)
}
sampleList.toList
}
// parse one SVMLight line ("label index:value ...") into a dense LabeledPoint
private def fromSVMStringToLabeledPoint(line: String): LabeledPoint = {
val labelAndFeatures = line.split(" ")
val label = labelAndFeatures(0).toInt
val features = labelAndFeatures.tail
val denseFeature = new Array[Double](129) // dense buffer; the demo data keeps feature indices below 129
for (feature <- features) {
val idAndValue = feature.split(":")
denseFeature(idAndValue(0).toInt) = idAndValue(1).toDouble
}
LabeledPoint(label, new DenseVector(denseFeature))
}
def main(args: Array[String]): Unit = {
// brings the implicit Spark -> XGBoost LabeledPoint conversions into scope
import ml.dmlc.xgboost4j.scala.spark.DataUtils._
if (args.length != 4) {
println(
"usage: program number_of_trainingset_partitions num_of_rounds training_path test_path")
sys.exit(1)
}
val sc = new SparkContext()
val inputTrainPath = args(2)
val inputTestPath = args(3)
val trainingLabeledPoints = readFile(inputTrainPath)
val trainRDD = sc.parallelize(trainingLabeledPoints, args(0).toInt)
val testLabeledPoints = readFile(inputTestPath).iterator
val testMatrix = new DMatrix(testLabeledPoints, null)
val booster = XGBoost.train(trainRDD,
List("eta" -> "1", "max_depth" -> "2", "silent" -> "0",
"objective" -> "binary:logistic").toMap, args(1).toInt, null, null)
booster.map(boosterInstance => boosterInstance.predict(testMatrix))
}
}
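
A hypothetical local smoke test for this example: since `new SparkContext()` builds its configuration from `spark.*` system properties, one can set a local master before calling `main`. All values below (master, app name, file paths, counts) are placeholders, not part of the commit:

  // hypothetical local run; in a real deployment spark-submit supplies the config
  System.setProperty("spark.master", "local[2]")
  System.setProperty("spark.app.name", "DistTrainWithSpark-demo")
  // 2 training partitions, 10 rounds, SVMLight-format train/test files
  DistTrainWithSpark.main(Array("2", "10", "agaricus.txt.train", "agaricus.txt.test"))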

View File: ml/dmlc/xgboost4j/scala/spark/DataUtils.scala

@@ -23,11 +23,16 @@ import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}
import ml.dmlc.xgboost4j.LabeledPoint
private[spark] object DataUtils extends Serializable {
object DataUtils extends Serializable {
implicit def fromSparkToXGBoostLabeledPointsAsJava(
sps: Iterator[SparkLabeledPoint]): java.util.Iterator[LabeledPoint] = {
fromSparkToXGBoostLabeledPoints(sps).asJava
}
implicit def fromSparkToXGBoostLabeledPoints(sps: Iterator[SparkLabeledPoint]):
java.util.Iterator[LabeledPoint] = {
(for (p <- sps) yield {
Iterator[LabeledPoint] = {
for (p <- sps) yield {
p.features match {
case denseFeature: DenseVector =>
LabeledPoint.fromDenseVector(p.label.toFloat, denseFeature.values.map(_.toFloat))
@@ -35,17 +40,6 @@ private[spark] object DataUtils extends Serializable {
LabeledPoint.fromSparseVector(p.label.toFloat, sparseFeature.indices,
sparseFeature.values.map(_.toFloat))
}
}).asJava
}
private def fetchUpdateFromSparseVector(sparseFeature: SparseVector): (List[Int], List[Float]) = {
(sparseFeature.indices.toList, sparseFeature.values.map(_.toFloat).toList)
}
private def fetchUpdateFromVector(feature: Vector) = feature match {
case denseFeature: DenseVector =>
fetchUpdateFromSparseVector(denseFeature.toSparse)
case sparseFeature: SparseVector =>
fetchUpdateFromSparseVector(sparseFeature)
}
}
}
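
With `DataUtils` now public, user code can rely on these implicits directly. A minimal sketch, assuming the scala `DMatrix` constructor form used by the demo above (iterator plus a nullable cache path):

  import ml.dmlc.xgboost4j.scala.DMatrix
  import ml.dmlc.xgboost4j.scala.spark.DataUtils._
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.regression.{LabeledPoint => SparkLabeledPoint}

  // fromSparkToXGBoostLabeledPoints implicitly bridges Spark points to the
  // Iterator[LabeledPoint] that the DMatrix constructor expects
  val points = Iterator(SparkLabeledPoint(1.0, Vectors.dense(0.5, 0.25)))
  val dmat = new DMatrix(points, null) // null cache path, mirroring the demo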

View File: ml/dmlc/xgboost4j/scala/spark/XGBoost.scala

@@ -61,7 +61,8 @@ object XGBoost extends Serializable {
require(tracker.start(), "FAULT: Failed to start tracker")
boosters = buildDistributedBoosters(trainingData, configMap, numWorkers, round, obj, eval)
// force the job
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
boosters.foreachPartition(_ => ())
println("=====finished training=====")
val booster = boosters.first()
val returnVal = tracker.waitFor()
logger.info(s"Rabit returns with exit code $returnVal")
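
The replaced line and its replacement are equivalent job triggers; a minimal standalone sketch of the idea (RDD contents are illustrative):

  // both forms force evaluation of a lazy RDD and discard the results;
  // foreachPartition does it through the public RDD API instead of sc.runJob
  val rdd = sc.parallelize(1 to 10, numSlices = 2)
  rdd.foreachPartition(_ => ())                // runs one task per partition
  // sc.runJob(rdd, (it: Iterator[Int]) => it) // the lower-level form this commit drops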

View File: ml/dmlc/xgboost4j/LabeledPoint.java

@@ -1,10 +1,12 @@
package ml.dmlc.xgboost4j;
import java.io.Serializable;
/**
* Labeled data point for training examples.
* Represent a sparse training instance.
*/
public class LabeledPoint {
public class LabeledPoint implements Serializable {
/** Label of the point */
public float label;
/** Weight of this data point */

View File: ml/dmlc/xgboost4j/java/Booster.java

@@ -1,41 +1,140 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.Serializable;
import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public interface Booster extends Serializable {
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Booster for xgboost, similar to the python wrapper xgboost.py
* but custom obj and eval functions are not supported at present.
*
* @author hzx
*/
public class Booster implements Serializable {
private static final Log logger = LogFactory.getLog(Booster.class);
long handle = 0;
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* init Booster from dMatrixs
*
* @param params parameters
* @param dMatrixs DMatrix array
* @throws XGBoostError native error
*/
Booster(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
init(dMatrixs);
setParam("seed", "0");
setParams(params);
}
/**
* load model from modelPath
*
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
* @throws XGBoostError native error
*/
Booster(Map<String, Object> params, String modelPath) throws XGBoostError {
init(null);
if (modelPath == null) {
throw new NullPointerException("modelPath : null");
}
loadModel(modelPath);
setParam("seed", "0");
setParams(params);
}
private void init(DMatrix[] dMatrixs) throws XGBoostError {
long[] handles = null;
if (dMatrixs != null) {
handles = dmatrixsToHandles(dMatrixs);
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
handle = out[0];
}
/**
* set parameter
*
* @param key param name
* @param value param value
* @throws XGBoostError native error
*/
void setParam(String key, String value) throws XGBoostError;
public final void setParam(String key, String value) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
}
/**
* set parameters
*
* @param params parameters key-value map
* @throws XGBoostError native error
*/
void setParams(Map<String, Object> params) throws XGBoostError;
public void setParams(Map<String, Object> params) throws XGBoostError {
if (params != null) {
for (Map.Entry<String, Object> entry : params.entrySet()) {
setParam(entry.getKey(), entry.getValue().toString());
}
}
}
/**
* Update (one iteration)
*
* @param dtrain training data
* @param iter current iteration number
* @throws XGBoostError native error
*/
void update(DMatrix dtrain, int iter) throws XGBoostError;
public void update(DMatrix dtrain, int iter) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
/**
* update with a customized objective function
*
* @param dtrain training data
* @param obj customized objective class
* @throws XGBoostError native error
*/
void update(DMatrix dtrain, IObjective obj) throws XGBoostError;
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
float[][] predicts = predict(dtrain, true);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
}
/**
* update with given grad and hess
@@ -43,8 +142,16 @@ public interface Booster extends Serializable {
* @param dtrain training data
* @param grad first order of gradient
* @param hess second order of gradient
* @throws XGBoostError native error
*/
void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError;
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
if (grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
hess));
}
/**
* evaluate with given dmatrixs.
@@ -53,8 +160,15 @@ public interface Booster extends Serializable {
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
* @throws XGBoostError native error
*/
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError;
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
long[] handles = dmatrixsToHandles(evalMatrixs);
String[] evalInfo = new String[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
evalInfo));
return evalInfo[0];
}
/**
* evaluate with given customized Evaluation class
@@ -63,17 +177,64 @@ public interface Booster extends Serializable {
* @param evalNames evaluation names
* @param eval custom evaluator
* @return eval information
* @throws XGBoostError native error
*/
String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError;
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
throws XGBoostError {
String evalInfo = "";
for (int i = 0; i < evalNames.length; i++) {
String evalName = evalNames[i];
DMatrix evalMat = evalMatrixs[i];
float evalResult = eval.eval(predict(evalMat), evalMat);
String evalMetric = eval.getMetric();
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
}
return evalInfo;
}
/**
* base function for Predict
*
* @param data data
* @param outPutMargin output margin
* @param treeLimit limit number of trees
* @param predLeaf whether to output the predicted leaf index of each sample in each tree
* @return predict results
*/
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
boolean predLeaf) throws XGBoostError {
int optionMask = 0;
if (outPutMargin) {
optionMask = 1; // output raw, untransformed margins
}
if (predLeaf) {
optionMask = 2; // output leaf indices; note this overwrites (not ORs) the margin flag
}
float[][] rawPredicts = new float[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
treeLimit, rawPredicts));
int row = (int) data.rowNum();
int col = rawPredicts[0].length / row;
float[][] predicts = new float[row][col];
int r, c;
for (int i = 0; i < rawPredicts[0].length; i++) {
r = i / col;
c = i % col;
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @return predict result
* @throws XGBoostError native error
*/
float[][] predict(DMatrix data) throws XGBoostError;
public float[][] predict(DMatrix data) throws XGBoostError {
return pred(data, false, 0, false);
}
/**
* Predict with data
@@ -81,9 +242,11 @@ public interface Booster extends Serializable {
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
* @throws XGBoostError native error
*/
float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError;
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
return pred(data, outPutMargin, 0, false);
}
/**
* Predict with data
@@ -92,31 +255,189 @@ public interface Booster extends Serializable {
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
* @throws XGBoostError native error
*/
float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError;
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
return pred(data, outPutMargin, treeLimit, false);
}
/**
* Predict with data
* @param data dmatrix storing the input
*
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index of
* each sample in each tree. Note that the leaf index of a tree is unique per
* tree, so you may find leaf 1 in both tree 1 and tree 0.
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index
* of each sample in each tree.
* Note that the leaf index of a tree is unique per tree, so you may find leaf 1
* in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError;
public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
return pred(data, false, treeLimit, predLeaf);
}
/**
* save model to modelPath; which paths are supported depends on the path
* support compiled into libxgboost. For example, saving to HDFS requires
* libxgboost to be compiled with HDFS support.
* See also toByteArray
* save model to modelPath
*
* @param modelPath model path
*/
void saveModel(String modelPath) throws XGBoostError;
public void saveModel(String modelPath) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
}
private void loadModel(String modelPath) {
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
}
/**
* get the dump of the model as a string array
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
return modelInfos[0];
}
/**
* get the dump of the model as a string array
*
* @param featureMap featureMap file
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
modelInfos));
return modelInfos[0];
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool
* Controls whether the split statistics are output.
* @throws FileNotFoundException file not found
* @throws UnsupportedEncodingException unsupported encoding
* @throws IOException error with model writing
* @throws XGBoostError native error
*/
public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(withStats);
for (int i = 0; i < modelInfos.length; i++) {
writer.write("booster [" + i + "]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
* @throws FileNotFoundException file not found
* @throws UnsupportedEncodingException unsupported encoding
* @throws IOException error with model writing
* @throws XGBoostError native error
*/
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(featureMap, withStats);
for (int i = 0; i < modelInfos.length; i++) {
writer.write("booster [" + i + "]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore() throws XGBoostError {
String[] modelInfos = getDumpInfo(false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* get importance of each feature
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
String[] modelInfos = getDumpInfo(featureMap, false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
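A short Scala sketch of consuming the importance map after training; `booster` is assumed to be a trained instance, and feature ids such as "f12" reflect the default dump format:

  import scala.collection.JavaConverters._

  // split counts per feature id, e.g. "f12" -> 3 (illustrative values)
  val scores = booster.getFeatureScore().asScala
  scores.toSeq.sortBy(p => -p._2.intValue).take(5).foreach { case (fid, n) =>
    println(s"$fid used in $n splits")
  }
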
/**
* Save the model as byte array representation.
@@ -127,41 +448,77 @@ public interface Booster extends Serializable {
* @return the saved byte array.
* @throws XGBoostError
*/
byte[] toByteArray() throws XGBoostError;
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
return bytes[0];
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool Controls whether the split statistics are output.
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError;
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
* Save the booster model into thread-local rabit checkpoint.
* This is only used in distributed training.
* @throws XGBoostError
*/
void dumpModel(String modelPath, String featureMap, boolean withStats)
throws IOException, XGBoostError;
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/**
* get importance of each feature
* transfer DMatrix array to handle array (used for native functions)
*
* @return featureMap key: feature index, value: feature importance score
* @param dmatrixs
* @return handle array for input dmatrixs
*/
Map<String, Integer> getFeatureScore() throws XGBoostError ;
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
long[] handles = new long[dmatrixs.length];
for (int i = 0; i < dmatrixs.length; i++) {
handles[i] = dmatrixs[i].getHandle();
}
return handles;
}
/**
* get importance of each feature
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
*/
Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError;
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
void dispose();
private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
try {
this.init(null);
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
@Override
protected void finalize() throws Throwable {
super.finalize();
dispose();
}
public synchronized void dispose() {
if (handle != 0L) {
XGBoostJNI.XGBoosterFree(handle);
handle = 0;
}
}
}
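
Making `Booster` Serializable (via the raw model bytes in writeObject/readObject above) is what lets Spark ship trained boosters between driver and executors. A minimal round-trip sketch in Scala, with `booster` assumed trained:

  import java.io._

  // Java serialization now carries only the raw model bytes (see writeObject above)
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(booster)
  oos.flush()
  val restored = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    .readObject().asInstanceOf[ml.dmlc.xgboost4j.java.Booster]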

View File: ml/dmlc/xgboost4j/java/DataBatch.java

@@ -1,5 +1,6 @@
package ml.dmlc.xgboost4j.java;
import java.io.Serializable;
import java.util.Iterator;
import ml.dmlc.xgboost4j.LabeledPoint;
@@ -56,7 +57,7 @@ class DataBatch {
return b;
}
static class BatchIterator implements Iterator<DataBatch> {
static class BatchIterator implements Iterator<DataBatch>, Serializable {
private Iterator<LabeledPoint> base;
private int batchSize;

View File: ml/dmlc/xgboost4j/java/JavaBoosterImpl.java (deleted)

@@ -1,525 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.io.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Booster for xgboost, similar to the python wrapper xgboost.py
* but custom obj and eval functions are not supported at present.
*
* @author hzx
*/
class JavaBoosterImpl implements Booster {
private static final Log logger = LogFactory.getLog(JavaBoosterImpl.class);
long handle = 0;
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* init Booster from dMatrixs
*
* @param params parameters
* @param dMatrixs DMatrix array
* @throws XGBoostError native error
*/
JavaBoosterImpl(Map<String, Object> params, DMatrix[] dMatrixs) throws XGBoostError {
init(dMatrixs);
setParam("seed", "0");
setParams(params);
}
/**
* load model from modelPath
*
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
* @throws XGBoostError native error
*/
JavaBoosterImpl(Map<String, Object> params, String modelPath) throws XGBoostError {
init(null);
if (modelPath == null) {
throw new NullPointerException("modelPath : null");
}
loadModel(modelPath);
setParam("seed", "0");
setParams(params);
}
private void init(DMatrix[] dMatrixs) throws XGBoostError {
long[] handles = null;
if (dMatrixs != null) {
handles = dmatrixsToHandles(dMatrixs);
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
handle = out[0];
}
/**
* set parameter
*
* @param key param name
* @param value param value
* @throws XGBoostError native error
*/
public final void setParam(String key, String value) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value));
}
/**
* set parameters
*
* @param params parameters key-value map
* @throws XGBoostError native error
*/
public void setParams(Map<String, Object> params) throws XGBoostError {
if (params != null) {
for (Map.Entry<String, Object> entry : params.entrySet()) {
setParam(entry.getKey(), entry.getValue().toString());
}
}
}
/**
* Update (one iteration)
*
* @param dtrain training data
* @param iter current iteration number
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, int iter) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
/**
* update with a customized objective function
*
* @param dtrain training data
* @param obj customized objective class
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, IObjective obj) throws XGBoostError {
float[][] predicts = predict(dtrain, true);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
}
/**
* update with given grad and hess
*
* @param dtrain training data
* @param grad first order of gradient
* @param hess seconde order of gradient
* @throws XGBoostError native error
*/
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
if (grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad,
hess));
}
/**
* evaluate with given dmatrixs.
*
* @param evalMatrixs dmatrixs for evaluation
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
* @throws XGBoostError native error
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
long[] handles = dmatrixsToHandles(evalMatrixs);
String[] evalInfo = new String[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
evalInfo));
return evalInfo[0];
}
/**
* evaluate with given customized Evaluation class
*
* @param evalMatrixs evaluation matrix
* @param evalNames evaluation names
* @param eval custom evaluator
* @return eval information
* @throws XGBoostError native error
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval)
throws XGBoostError {
String evalInfo = "";
for (int i = 0; i < evalNames.length; i++) {
String evalName = evalNames[i];
DMatrix evalMat = evalMatrixs[i];
float evalResult = eval.eval(predict(evalMat), evalMat);
String evalMetric = eval.getMetric();
evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
}
return evalInfo;
}
/**
* base function for Predict
*
* @param data data
* @param outPutMargin output margin
* @param treeLimit limit number of trees
* @param predLeaf whether to output the predicted leaf index of each sample in each tree
* @return predict results
*/
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit,
boolean predLeaf) throws XGBoostError {
int optionMask = 0;
if (outPutMargin) {
optionMask = 1;
}
if (predLeaf) {
optionMask = 2;
}
float[][] rawPredicts = new float[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
treeLimit, rawPredicts));
int row = (int) data.rowNum();
int col = rawPredicts[0].length / row;
float[][] predicts = new float[row][col];
int r, c;
for (int i = 0; i < rawPredicts[0].length; i++) {
r = i / col;
c = i % col;
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data) throws XGBoostError {
return pred(data, false, 0, false);
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError {
return pred(data, outPutMargin, 0, false);
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
return pred(data, outPutMargin, treeLimit, false);
}
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index
* of each sample in each tree.
* Note that the leaf index of a tree is unique per tree, so you may find leaf 1
* in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError {
return pred(data, false, treeLimit, predLeaf);
}
/**
* save model to modelPath
*
* @param modelPath model path
*/
public void saveModel(String modelPath) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
}
private void loadModel(String modelPath) {
XGBoostJNI.XGBoosterLoadModel(handle, modelPath);
}
/**
* get the dump of the model as a string array
*
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
return modelInfos[0];
}
/**
* get the dump of the model as a string array
*
* @param featureMap featureMap file
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws XGBoostError native error
*/
private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag,
modelInfos));
return modelInfos[0];
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool
* Controls whether the split statistics are output.
* @throws FileNotFoundException file not found
* @throws UnsupportedEncodingException unsupported encoding
* @throws IOException error with model writing
* @throws XGBoostError native error
*/
public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(withStats);
for (int i = 0; i < modelInfos.length; i++) {
writer.write("booster [" + i + "]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
* @throws FileNotFoundException file not found
* @throws UnsupportedEncodingException unsupported encoding
* @throws IOException error with model writing
* @throws XGBoostError native error
*/
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws
IOException, XGBoostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
String[] modelInfos = getDumpInfo(featureMap, withStats);
for (int i = 0; i < modelInfos.length; i++) {
writer.write("booster [" + i + "]:\n");
writer.write(modelInfos[i]);
}
writer.close();
out.close();
}
/**
* get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore() throws XGBoostError {
String[] modelInfos = getDumpInfo(false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* get importance of each feature
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
String[] modelInfos = getDumpInfo(featureMap, false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* Save the model as byte array representation.
* Writing these bytes to a file gives a format compatible with other xgboost bindings.
*
* If Java natively supports the HDFS file API, use toByteArray and write the byte array.
*
* @return the saved byte array.
* @throws XGBoostError
*/
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
return bytes[0];
}
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
/**
* Save the booster model into thread-local rabit checkpoint.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/**
* transfer DMatrix array to handle array (used for native functions)
*
* @param dmatrixs
* @return handle array for input dmatrixs
*/
private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) {
long[] handles = new long[dmatrixs.length];
for (int i = 0; i < dmatrixs.length; i++) {
handles[i] = dmatrixs[i].getHandle();
}
return handles;
}
// making Booster serializable
private void writeObject(java.io.ObjectOutputStream out) throws IOException {
try {
out.writeObject(this.toByteArray());
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
private void readObject(java.io.ObjectInputStream in)
throws IOException, ClassNotFoundException {
try {
this.init(null);
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
throw new IOException(ex.toString());
}
}
@Override
protected void finalize() throws Throwable {
super.finalize();
dispose();
}
public synchronized void dispose() {
if (handle != 0L) {
XGBoostJNI.XGBoosterFree(handle);
handle = 0;
}
}
}

View File: ml/dmlc/xgboost4j/java/Rabit.java

@@ -1,6 +1,7 @@
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.Serializable;
import java.util.Map;
import org.apache.commons.logging.Log;
@@ -9,7 +10,7 @@ import org.apache.commons.logging.LogFactory;
/**
* Rabit global class for synchronization.
*/
public class Rabit {
public class Rabit implements Serializable {
private static final Log logger = LogFactory.getLog(Rabit.class);
//load native library
static {

View File: ml/dmlc/xgboost4j/java/XGBoost.java

@@ -71,7 +71,7 @@ public class XGBoost {
}
//initialize booster
JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
Booster booster = new Booster(params, allMats);
int version = booster.loadRabitCheckpoint();
@@ -115,7 +115,7 @@ public class XGBoost {
public static Booster initBoostingModel(
Map<String, Object> params,
DMatrix[] dMatrixs) throws XGBoostError {
return new JavaBoosterImpl(params, dMatrixs);
return new Booster(params, dMatrixs);
}
/**
@@ -127,7 +127,7 @@ public class XGBoost {
*/
public static Booster loadBoostModel(Map<String, Object> params, String modelPath)
throws XGBoostError {
return new JavaBoosterImpl(params, modelPath);
return new Booster(params, modelPath);
}
/**

View File: ml/dmlc/xgboost4j/scala/Booster.scala

@@ -16,172 +16,86 @@
package ml.dmlc.xgboost4j.scala
import java.io.IOException
import ml.dmlc.xgboost4j.java.XGBoostError
import ml.dmlc.xgboost4j.java
import scala.collection.JavaConverters._
import scala.collection.mutable
trait Booster extends Serializable {
class Booster private[xgboost4j](booster: java.Booster) extends Serializable {
def setParam(key: String, value: String): Unit = {
booster.setParam(key, value)
}
def update(dtrain: DMatrix, iter: Int): Unit = {
booster.update(dtrain.jDMatrix, iter)
}
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
}
def dumpModel(modelPath: String, withStats: Boolean): Unit = {
booster.dumpModel(modelPath, withStats)
}
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
booster.dumpModel(modelPath, featureMap, withStats)
}
def setParams(params: Map[String, AnyRef]): Unit = {
booster.setParams(params.asJava)
}
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
}
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
}
def dispose: Unit = {
booster.dispose()
}
def predict(data: DMatrix): Array[Array[Float]] = {
booster.predict(data.jDMatrix)
}
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin)
}
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
}
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
booster.predict(data.jDMatrix, treeLimit, predLeaf)
}
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
}
def getFeatureScore: mutable.Map[String, Integer] = {
booster.getFeatureScore.asScala
}
def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
booster.getFeatureScore(featureMap).asScala
}
def saveModel(modelPath: String): Unit = {
booster.saveModel(modelPath)
}
override def finalize(): Unit = {
super.finalize()
dispose
}
/**
* set parameter
*
* @param key param name
* @param value param value
*/
@throws(classOf[XGBoostError])
def setParam(key: String, value: String)
/**
* set parameters
*
* @param params parameters key-value map
*/
@throws(classOf[XGBoostError])
def setParams(params: Map[String, AnyRef])
/**
* Update (one iteration)
*
* @param dtrain training data
* @param iter current iteration number
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, iter: Int)
/**
* update with a customized objective function
*
* @param dtrain training data
* @param obj customized objective class
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, obj: ObjectiveTrait)
/**
* update with given grad and hess
*
* @param dtrain training data
* @param grad first order of gradient
* @param hess second order of gradient
*/
@throws(classOf[XGBoostError])
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float])
/**
* evaluate with given dmatrixs.
*
* @param evalMatrixs dmatrixs for evaluation
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
*/
@throws(classOf[XGBoostError])
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String
/**
* evaluate with given customized Evaluation class
*
* @param evalMatrixs evaluation matrix
* @param evalNames evaluation names
* @param eval custom evaluator
* @return eval information
*/
@throws(classOf[XGBoostError])
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String
/**
* Predict with data
*
* @param data dmatrix storing the input
* @return predict result
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix): Array[Array[Float]]
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]]
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]]
/**
* Predict with data
*
* @param data dmatrix storing the input
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees),
* nsample = data.numRow with each record indicating the predicted leaf index of
* each sample in each tree. Note that the leaf index of a tree is unique per
* tree, so you may find leaf 1 in both tree 1 and tree 0.
* @return predict result
* @throws XGBoostError native error
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]]
/**
* save model to modelPath
*
* @param modelPath model path
*/
@throws(classOf[XGBoostError])
def saveModel(modelPath: String)
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param withStats bool Controls whether the split statistics are output.
*/
@throws(classOf[IOException])
@throws(classOf[XGBoostError])
def dumpModel(modelPath: String, withStats: Boolean)
/**
* Dump model into a text file.
*
* @param modelPath file to save dumped model info
* @param featureMap featureMap file
* @param withStats bool
* Controls whether the split statistics are output.
*/
@throws(classOf[IOException])
@throws(classOf[XGBoostError])
def dumpModel(modelPath: String, featureMap: String, withStats: Boolean)
/**
* get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getFeatureScore: mutable.Map[String, Integer]
/**
* get importance of each feature
*
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getFeatureScore(featureMap: String): mutable.Map[String, Integer]
def dispose
}

View File: ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala (deleted)

@@ -1,99 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java
import scala.collection.JavaConverters._
import scala.collection.mutable
private[scala] class ScalaBoosterImpl private[xgboost4j](booster: java.Booster) extends Booster {
override def setParam(key: String, value: String): Unit = {
booster.setParam(key, value)
}
override def update(dtrain: DMatrix, iter: Int): Unit = {
booster.update(dtrain.jDMatrix, iter)
}
override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
}
override def dumpModel(modelPath: String, withStats: Boolean): Unit = {
booster.dumpModel(modelPath, withStats)
}
override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = {
booster.dumpModel(modelPath, featureMap, withStats)
}
override def setParams(params: Map[String, AnyRef]): Unit = {
booster.setParams(params.asJava)
}
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
}
override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait):
String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval)
}
override def dispose: Unit = {
booster.dispose()
}
override def predict(data: DMatrix): Array[Array[Float]] = {
booster.predict(data.jDMatrix)
}
override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin)
}
override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int):
Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
}
override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = {
booster.predict(data.jDMatrix, treeLimit, predLeaf)
}
override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
}
override def getFeatureScore: mutable.Map[String, Integer] = {
booster.getFeatureScore.asScala
}
override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = {
booster.getFeatureScore(featureMap).asScala
}
override def saveModel(modelPath: String): Unit = {
booster.saveModel(modelPath)
}
override def finalize(): Unit = {
super.finalize()
dispose
}
}

View File: ml/dmlc/xgboost4j/scala/XGBoost.scala

@@ -16,9 +16,10 @@
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.java.{XGBoost => JXGBoost}
object XGBoost {
def train(
@@ -31,7 +32,7 @@ object XGBoost {
val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)}
val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava,
obj, eval)
new ScalaBoosterImpl(xgboostInJava)
new Booster(xgboostInJava)
}
def crossValidation(
@@ -47,11 +48,11 @@ object XGBoost {
def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = {
val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix))
new ScalaBoosterImpl(xgboostInJava)
new Booster(xgboostInJava)
}
def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = {
val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath)
new ScalaBoosterImpl(xgboostInJava)
new Booster(xgboostInJava)
}
}
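
A hedged usage sketch of the scala wrapper, with the `train` signature inferred from the body above (params, dtrain, round, watches, obj, eval); the file-path `DMatrix` constructor and all paths are assumptions:

  import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}

  val dtrain = new DMatrix("agaricus.txt.train") // assumed file-based constructor
  val params = Map("eta" -> "0.3", "max_depth" -> "2",
    "objective" -> "binary:logistic")
  // 10 rounds, watch the training matrix, no custom obj/eval
  val booster = XGBoost.train(params, dtrain, 10, Map("train" -> dtrain), null, null)
  booster.saveModel("model.bin")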