revise current API
This commit is contained in:
@@ -129,6 +129,7 @@ public class RabitTracker {
|
||||
public boolean start() {
|
||||
if (startTrackerProcess()) {
|
||||
logger.debug("Tracker started, with env=" + envs.toString());
|
||||
System.out.println("Tracker started, with env=" + envs.toString());
|
||||
// also start a tracker logger
|
||||
Thread logger_thread = new Thread(new TrackerProcessLogger());
|
||||
logger_thread.setDaemon(true);
|
||||
|
||||
@@ -119,7 +119,7 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0)
|
||||
: Array[Array[Float]] = {
|
||||
: Array[Array[Float]] = {
|
||||
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
|
||||
}
|
||||
|
||||
@@ -178,6 +178,10 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
|
||||
booster.getFeatureScore(featureMap).asScala
|
||||
}
|
||||
|
||||
def toByteArray: Array[Byte] = {
|
||||
booster.toByteArray
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispose the booster when it is no longer needed
|
||||
*/
|
||||
|
||||
@@ -17,7 +17,10 @@ package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
@@ -70,11 +73,7 @@ public class BoosterImplTest {
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoosterBasic() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
private Booster trainBooster(DMatrix trainMat, DMatrix testMat) throws XGBoostError {
|
||||
//set params
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
@@ -92,10 +91,19 @@ public class BoosterImplTest {
|
||||
watches.put("test", testMat);
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
int round = 5;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
|
||||
return XGBoost.train(paramMap, trainMat, round, watches, null, null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoosterBasic() throws XGBoostError, IOException {
|
||||
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
|
||||
//predict raw output
|
||||
float[][] predicts = booster.predict(testMat, true, 0);
|
||||
@@ -104,14 +112,43 @@ public class BoosterImplTest {
|
||||
IEvaluation eval = new EvalError();
|
||||
//error must be less than 0.1
|
||||
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithPath() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
IEvaluation eval = new EvalError();
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
// save and load
|
||||
File temp = File.createTempFile("temp", "model");
|
||||
temp.deleteOnExit();
|
||||
booster.saveModel(temp.getAbsolutePath());
|
||||
|
||||
Booster bst2 = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath()));
|
||||
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
|
||||
assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray()));
|
||||
float[][] predicts2 = bst2.predict(testMat, true, 0);
|
||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void saveLoadModelWithStream() throws XGBoostError, IOException {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
|
||||
Path tempDir = Files.createTempDirectory("boosterTest-");
|
||||
File tempFile = Files.createTempFile("", "").toFile();
|
||||
booster.saveModel(new FileOutputStream(tempFile));
|
||||
IEvaluation eval = new EvalError();
|
||||
Booster loadedBooster = XGBoost.loadModel(new FileInputStream(tempFile));
|
||||
float originalPredictError = eval.eval(booster.predict(testMat, true), testMat);
|
||||
TestCase.assertTrue("originalPredictErr:" + originalPredictError,
|
||||
originalPredictError < 0.1f);
|
||||
float loadedPredictError = eval.eval(loadedBooster.predict(testMat, true), testMat);
|
||||
TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -16,10 +16,14 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
import java.io.{FileOutputStream, FileInputStream, File}
|
||||
|
||||
import junit.framework.TestCase
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError
|
||||
|
||||
class ScalaBoosterImplSuite extends FunSuite {
|
||||
|
||||
private class EvalError extends EvalTrait {
|
||||
@@ -64,21 +68,58 @@ class ScalaBoosterImplSuite extends FunSuite {
|
||||
}
|
||||
}
|
||||
|
||||
test("basic operation of booster") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
|
||||
private def trainBooster(trainMat: DMatrix, testMat: DMatrix): Booster = {
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val watches = List("train" -> trainMat, "test" -> testMat).toMap
|
||||
|
||||
val round = 2
|
||||
val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
|
||||
XGBoost.train(paramMap, trainMat, round, watches, null, null)
|
||||
}
|
||||
|
||||
test("basic operation of booster") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
|
||||
val booster = trainBooster(trainMat, testMat)
|
||||
val predicts = booster.predict(testMat, true)
|
||||
val eval = new EvalError
|
||||
assert(eval.eval(predicts, testMat) < 0.1)
|
||||
}
|
||||
|
||||
test("save/load model with path") {
|
||||
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val eval = new EvalError
|
||||
val booster = trainBooster(trainMat, testMat)
|
||||
// save and load
|
||||
val temp: File = File.createTempFile("temp", "model")
|
||||
temp.deleteOnExit()
|
||||
booster.saveModel(temp.getAbsolutePath)
|
||||
|
||||
val bst2: Booster = XGBoost.loadModel(temp.getAbsolutePath)
|
||||
assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray))
|
||||
val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0)
|
||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f)
|
||||
}
|
||||
|
||||
test("save/load model with stream") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
val eval = new EvalError
|
||||
val booster = trainBooster(trainMat, testMat)
|
||||
// save and load
|
||||
val temp: File = File.createTempFile("temp", "model")
|
||||
temp.deleteOnExit()
|
||||
booster.saveModel(new FileOutputStream(temp.getAbsolutePath))
|
||||
|
||||
val bst2: Booster = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath))
|
||||
assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray))
|
||||
val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0)
|
||||
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f)
|
||||
}
|
||||
|
||||
test("cross validation") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6",
|
||||
|
||||
Reference in New Issue
Block a user