add test cases for Scala API

This commit is contained in:
CodingCat 2016-03-02 15:24:13 -05:00
parent f8fff6c6fc
commit 5e309f1ce8
10 changed files with 225 additions and 29 deletions

View File

@ -95,13 +95,25 @@
<groupId>org.apache.maven.plugins</groupId> <groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId> <artifactId>maven-surefire-plugin</artifactId>
<version>2.19.1</version> <version>2.19.1</version>
<configuration>
<argLine>-Djava.library.path=lib/</argLine>
</configuration>
</plugin> </plugin>
</plugins> </plugins>
</build> </build>
<dependencies> <dependencies>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-compiler</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency> <dependency>
<groupId>commons-logging</groupId> <groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId> <artifactId>commons-logging</artifactId>
@ -113,5 +125,10 @@
<version>2.2.6</version> <version>2.2.6</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.typesafe</groupId>
<artifactId>config</artifactId>
<version>1.3.0</version>
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -134,21 +134,6 @@ This file is divided into 3 sections:
<!-- ??? usually shouldn't be checked into the code base. --> <!-- ??? usually shouldn't be checked into the code base. -->
<check level="error" class="org.scalastyle.scalariform.NotImplementedErrorUsage" enabled="true"></check> <check level="error" class="org.scalastyle.scalariform.NotImplementedErrorUsage" enabled="true"></check>
<!-- As of SPARK-7558, all tests in Spark should extend o.a.s.SparkFunSuite instead of FunSuite directly -->
<check customId="funsuite" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
<parameters><parameter name="regex">^FunSuite[A-Za-z]*$</parameter></parameters>
<customMessage>Tests must extend org.apache.spark.SparkFunSuite instead.</customMessage>
</check>
<!-- As of SPARK-7977 all printlns need to be wrapped in '// scalastyle:off/on println' -->
<check customId="println" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
<parameters><parameter name="regex">^println$</parameter></parameters>
<customMessage><![CDATA[Are you sure you want to println? If yes, wrap the code block with
// scalastyle:off println
println(...)
// scalastyle:on println]]></customMessage>
</check>
<check customId="visiblefortesting" level="error" class="org.scalastyle.file.RegexChecker" enabled="true"> <check customId="visiblefortesting" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters><parameter name="regex">@VisibleForTesting</parameter></parameters> <parameters><parameter name="regex">@VisibleForTesting</parameter></parameters>
<customMessage><![CDATA[ <customMessage><![CDATA[

View File

@ -22,6 +22,19 @@
<nohelp>true</nohelp> <nohelp>true</nohelp>
</configuration> </configuration>
</plugin> </plugin>
<plugin>
<groupId>org.scalatest</groupId>
<artifactId>scalatest-maven-plugin</artifactId>
<version>1.0</version>
<executions>
<execution>
<id>test</id>
<goals>
<goal>test</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins> </plugins>
</build> </build>
<dependencies> <dependencies>

View File

@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError} import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
class DMatrix private(private[scala] val jDMatrix: JDMatrix) { class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
/** /**
* init DMatrix from file (svmlight format) * init DMatrix from file (svmlight format)

View File

@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j.scala package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.IEvaluation import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IEvaluation}
trait EvalTrait extends IEvaluation { trait EvalTrait extends IEvaluation {
@ -35,4 +35,8 @@ trait EvalTrait extends IEvaluation {
* @return result of the metric * @return result of the metric
*/ */
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
private[scala] def eval(predicts: Array[Array[Float]], jdmat: JDMatrix): Float = {
eval(predicts, new DMatrix(jdmat))
}
} }

View File

@ -16,7 +16,9 @@
package ml.dmlc.xgboost4j.scala package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.IObjective import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IObjective}
trait ObjectiveTrait extends IObjective { trait ObjectiveTrait extends IObjective {
/** /**
@ -26,5 +28,10 @@ trait ObjectiveTrait extends IObjective {
* @param dtrain training data * @param dtrain training data
* @return List with two float array, correspond to first order grad and second order grad * @return List with two float array, correspond to first order grad and second order grad
*/ */
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]] def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): List[Array[Float]]
private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
java.util.List[Array[Float]] = {
getGradient(predicts, new DMatrix(dtrain)).asJava
}
} }

View File

@ -80,7 +80,7 @@ public class BoosterImplTest {
}; };
//set watchList //set watchList
HashMap<String, DMatrix> watches = new HashMap<>(); HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
watches.put("train", trainMat); watches.put("train", trainMat);
watches.put("test", testMat); watches.put("test", testMat);
@ -129,10 +129,6 @@ public class BoosterImplTest {
//do 5-fold cross validation //do 5-fold cross validation
int round = 2; int round = 2;
int nfold = 5; int nfold = 5;
//set additional eval_metrics String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
String[] metrics = null;
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, metrics,
null, null);
} }
} }

View File

@ -57,7 +57,6 @@ public class DMatrixTest {
long[] rowHeaders = new long[]{0, 3, 7, 11}; long[] rowHeaders = new long[]{0, 3, 7, 11};
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR); DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
//check row num //check row num
System.out.println(dmat1.rowNum());
TestCase.assertTrue(dmat1.rowNum() == 3); TestCase.assertTrue(dmat1.rowNum() == 3);
//test set label //test set label
float[] label1 = new float[]{1, 0, 1}; float[] label1 = new float[]{1, 0, 1};

View File

@ -0,0 +1,85 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import java.util.Arrays
import scala.util.Random
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
import org.scalatest.FunSuite
class DMatrixSuite extends FunSuite {
test("create DMatrix from File") {
val dmat = new DMatrix("../../demo/data/agaricus.txt.test")
// get label
val labels: Array[Float] = dmat.getLabel
// check length
assert(dmat.rowNum === labels.length)
// set weights
val weights: Array[Float] = Arrays.copyOf(labels, labels.length)
dmat.setWeight(weights)
val dweights: Array[Float] = dmat.getWeight
assert(weights === dweights)
}
test("create DMatrix from CSR") {
// create Matrix from csr format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray
val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray
val rowHeaders = List[Long](0, 3, 7, 11).toArray
val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR)
assert(dmat1.rowNum === 3)
val label1 = List[Float](1, 0, 1).toArray
dmat1.setLabel(label1)
val label2 = dmat1.getLabel
assert(label2 === label1)
}
test("create DMatrix from DenseMatrix") {
val nrow = 10
val ncol = 5
val data0 = new Array[Float](nrow * ncol)
// put random nums
for (i <- data0.indices) {
data0(i) = Random.nextFloat()
}
// create label
val label0 = new Array[Float](nrow)
for (i <- label0.indices) {
label0(i) = Random.nextFloat()
}
val dmat0 = new DMatrix(data0, nrow, ncol)
dmat0.setLabel(label0)
// check
assert(dmat0.rowNum === 10)
assert(dmat0.getLabel.length === 10)
// set weights for each instance
val weights = new Array[Float](nrow)
for (i <- weights.indices) {
weights(i) = Random.nextFloat()
}
dmat0.setWeight(weights)
assert(weights === dmat0.getWeight)
}
}

View File

@ -0,0 +1,90 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.XGBoostError
import org.apache.commons.logging.LogFactory
import org.scalatest.FunSuite
class ScalaBoosterImplSuite extends FunSuite {
private class EvalError extends EvalTrait {
val logger = LogFactory.getLog(classOf[EvalError])
private[xgboost4j] var evalMetric: String = "custom_error"
/**
* get evaluate metric
*
* @return evalMetric
*/
override def getMetric: String = evalMetric
/**
* evaluate with predicts and data
*
* @param predicts predictions as array
* @param dmat data matrix to evaluate
* @return result of the metric
*/
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
var error: Float = 0f
var labels: Array[Float] = null
try {
labels = dmat.getLabel
} catch {
case ex: XGBoostError =>
logger.error(ex)
return -1f
}
val nrow: Int = predicts.length
for (i <- 0 until nrow) {
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
error += 1
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
error += 1
}
}
error / labels.length
}
}
test("basic operation of booster") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val watches = List("train" -> trainMat, "test" -> testMat).toMap
val round = 2
val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
val predicts = booster.predict(testMat, true)
val eval = new EvalError
assert(eval.eval(predicts, testMat) < 0.1)
}
test("cross validation") {
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6",
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
val round = 2
val nfold = 5
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
}
}