revise current API

This commit is contained in:
CodingCat
2016-03-07 21:48:16 -05:00
parent 9911771b02
commit fa03aaeb63
9 changed files with 170 additions and 64 deletions

View File

@@ -129,6 +129,7 @@ public class RabitTracker {
public boolean start() {
if (startTrackerProcess()) {
logger.debug("Tracker started, with env=" + envs.toString());
System.out.println("Tracker started, with env=" + envs.toString());
// also start a tracker logger
Thread logger_thread = new Thread(new TrackerProcessLogger());
logger_thread.setDaemon(true);

View File

@@ -119,7 +119,7 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0)
: Array[Array[Float]] = {
: Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
}
@@ -178,6 +178,10 @@ class Booster private[xgboost4j](booster: JBooster) extends Serializable {
booster.getFeatureScore(featureMap).asScala
}
def toByteArray: Array[Byte] = {
booster.toByteArray
}
/**
* Dispose the booster when it is no longer needed
*/

View File

@@ -17,7 +17,10 @@ package ml.dmlc.xgboost4j.java;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -70,11 +73,7 @@ public class BoosterImplTest {
}
}
@Test
public void testBoosterBasic() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
private Booster trainBooster(DMatrix trainMat, DMatrix testMat) throws XGBoostError {
//set params
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
@@ -92,10 +91,19 @@ public class BoosterImplTest {
watches.put("test", testMat);
//set round
int round = 2;
int round = 5;
//train a boost model
Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
return XGBoost.train(paramMap, trainMat, round, watches, null, null);
}
@Test
public void testBoosterBasic() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
Booster booster = trainBooster(trainMat, testMat);
//predict raw output
float[][] predicts = booster.predict(testMat, true, 0);
@@ -104,14 +112,43 @@ public class BoosterImplTest {
IEvaluation eval = new EvalError();
//error must be less than 0.1
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
}
@Test
public void saveLoadModelWithPath() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
IEvaluation eval = new EvalError();
Booster booster = trainBooster(trainMat, testMat);
// save and load
File temp = File.createTempFile("temp", "model");
temp.deleteOnExit();
booster.saveModel(temp.getAbsolutePath());
Booster bst2 = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath()));
Booster bst2 = XGBoost.loadModel(temp.getAbsolutePath());
assert (Arrays.equals(bst2.toByteArray(), booster.toByteArray()));
float[][] predicts2 = bst2.predict(testMat, true, 0);
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f);
}
@Test
public void saveLoadModelWithStream() throws XGBoostError, IOException {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
Booster booster = trainBooster(trainMat, testMat);
Path tempDir = Files.createTempDirectory("boosterTest-");
File tempFile = Files.createTempFile("", "").toFile();
booster.saveModel(new FileOutputStream(tempFile));
IEvaluation eval = new EvalError();
Booster loadedBooster = XGBoost.loadModel(new FileInputStream(tempFile));
float originalPredictError = eval.eval(booster.predict(testMat, true), testMat);
TestCase.assertTrue("originalPredictErr:" + originalPredictError,
originalPredictError < 0.1f);
float loadedPredictError = eval.eval(loadedBooster.predict(testMat, true), testMat);
TestCase.assertTrue("loadedPredictErr:" + loadedPredictError, loadedPredictError < 0.1f);
}
/**

View File

@@ -16,10 +16,14 @@
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.java.XGBoostError
import java.io.{FileOutputStream, FileInputStream, File}
import junit.framework.TestCase
import org.apache.commons.logging.LogFactory
import org.scalatest.FunSuite
import ml.dmlc.xgboost4j.java.XGBoostError
class ScalaBoosterImplSuite extends FunSuite {
private class EvalError extends EvalTrait {
@@ -64,21 +68,58 @@ class ScalaBoosterImplSuite extends FunSuite {
}
}
test("basic operation of booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
private def trainBooster(trainMat: DMatrix, testMat: DMatrix): Booster = {
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val watches = List("train" -> trainMat, "test" -> testMat).toMap
val round = 2
val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
XGBoost.train(paramMap, trainMat, round, watches, null, null)
}
test("basic operation of booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val booster = trainBooster(trainMat, testMat)
val predicts = booster.predict(testMat, true)
val eval = new EvalError
assert(eval.eval(predicts, testMat) < 0.1)
}
test("save/load model with path") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val eval = new EvalError
val booster = trainBooster(trainMat, testMat)
// save and load
val temp: File = File.createTempFile("temp", "model")
temp.deleteOnExit()
booster.saveModel(temp.getAbsolutePath)
val bst2: Booster = XGBoost.loadModel(temp.getAbsolutePath)
assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray))
val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0)
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f)
}
test("save/load model with stream") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val eval = new EvalError
val booster = trainBooster(trainMat, testMat)
// save and load
val temp: File = File.createTempFile("temp", "model")
temp.deleteOnExit()
booster.saveModel(new FileOutputStream(temp.getAbsolutePath))
val bst2: Booster = XGBoost.loadModel(new FileInputStream(temp.getAbsolutePath))
assert(java.util.Arrays.equals(bst2.toByteArray, booster.toByteArray))
val predicts2: Array[Array[Float]] = bst2.predict(testMat, true, 0)
TestCase.assertTrue(eval.eval(predicts2, testMat) < 0.1f)
}
test("cross validation") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6",