add test cases for Scala API
This commit is contained in:
parent
f8fff6c6fc
commit
5e309f1ce8
@ -95,13 +95,25 @@
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<version>2.19.1</version>
|
||||
<configuration>
|
||||
<argLine>-Djava.library.path=lib/</argLine>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-compiler</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-reflect</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.scala-lang</groupId>
|
||||
<artifactId>scala-library</artifactId>
|
||||
<version>${scala.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-logging</groupId>
|
||||
<artifactId>commons-logging</artifactId>
|
||||
@ -113,5 +125,10 @@
|
||||
<version>2.2.6</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.typesafe</groupId>
|
||||
<artifactId>config</artifactId>
|
||||
<version>1.3.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
|
||||
@ -134,21 +134,6 @@ This file is divided into 3 sections:
|
||||
<!-- ??? usually shouldn't be checked into the code base. -->
|
||||
<check level="error" class="org.scalastyle.scalariform.NotImplementedErrorUsage" enabled="true"></check>
|
||||
|
||||
<!-- As of SPARK-7558, all tests in Spark should extend o.a.s.SparkFunSuite instead of FunSuite directly -->
|
||||
<check customId="funsuite" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
||||
<parameters><parameter name="regex">^FunSuite[A-Za-z]*$</parameter></parameters>
|
||||
<customMessage>Tests must extend org.apache.spark.SparkFunSuite instead.</customMessage>
|
||||
</check>
|
||||
|
||||
<!-- As of SPARK-7977 all printlns need to be wrapped in '// scalastyle:off/on println' -->
|
||||
<check customId="println" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
||||
<parameters><parameter name="regex">^println$</parameter></parameters>
|
||||
<customMessage><![CDATA[Are you sure you want to println? If yes, wrap the code block with
|
||||
// scalastyle:off println
|
||||
println(...)
|
||||
// scalastyle:on println]]></customMessage>
|
||||
</check>
|
||||
|
||||
<check customId="visiblefortesting" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||
<parameters><parameter name="regex">@VisibleForTesting</parameter></parameters>
|
||||
<customMessage><![CDATA[
|
||||
|
||||
@ -22,6 +22,19 @@
|
||||
<nohelp>true</nohelp>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.scalatest</groupId>
|
||||
<artifactId>scalatest-maven-plugin</artifactId>
|
||||
<version>1.0</version>
|
||||
<executions>
|
||||
<execution>
|
||||
<id>test</id>
|
||||
<goals>
|
||||
<goal>test</goal>
|
||||
</goals>
|
||||
</execution>
|
||||
</executions>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
<dependencies>
|
||||
|
||||
@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
||||
|
||||
class DMatrix private(private[scala] val jDMatrix: JDMatrix) {
|
||||
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||
|
||||
/**
|
||||
* init DMatrix from file (svmlight format)
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.IEvaluation
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IEvaluation}
|
||||
|
||||
trait EvalTrait extends IEvaluation {
|
||||
|
||||
@ -35,4 +35,8 @@ trait EvalTrait extends IEvaluation {
|
||||
* @return result of the metric
|
||||
*/
|
||||
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
||||
|
||||
private[scala] def eval(predicts: Array[Array[Float]], jdmat: JDMatrix): Float = {
|
||||
eval(predicts, new DMatrix(jdmat))
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,7 +16,9 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.IObjective
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IObjective}
|
||||
|
||||
trait ObjectiveTrait extends IObjective {
|
||||
/**
|
||||
@ -26,5 +28,10 @@ trait ObjectiveTrait extends IObjective {
|
||||
* @param dtrain training data
|
||||
* @return List with two float array, correspond to first order grad and second order grad
|
||||
*/
|
||||
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
|
||||
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): List[Array[Float]]
|
||||
|
||||
private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
|
||||
java.util.List[Array[Float]] = {
|
||||
getGradient(predicts, new DMatrix(dtrain)).asJava
|
||||
}
|
||||
}
|
||||
|
||||
@ -80,7 +80,7 @@ public class BoosterImplTest {
|
||||
};
|
||||
|
||||
//set watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<>();
|
||||
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||
|
||||
watches.put("train", trainMat);
|
||||
watches.put("test", testMat);
|
||||
@ -129,10 +129,6 @@ public class BoosterImplTest {
|
||||
//do 5-fold cross validation
|
||||
int round = 2;
|
||||
int nfold = 5;
|
||||
//set additional eval_metrics
|
||||
String[] metrics = null;
|
||||
|
||||
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, metrics,
|
||||
null, null);
|
||||
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
@ -57,7 +57,6 @@ public class DMatrixTest {
|
||||
long[] rowHeaders = new long[]{0, 3, 7, 11};
|
||||
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
|
||||
//check row num
|
||||
System.out.println(dmat1.rowNum());
|
||||
TestCase.assertTrue(dmat1.rowNum() == 3);
|
||||
//test set label
|
||||
float[] label1 = new float[]{1, 0, 1};
|
||||
|
||||
@ -0,0 +1,85 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.util.Arrays
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class DMatrixSuite extends FunSuite {
|
||||
test("create DMatrix from File") {
|
||||
val dmat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
// get label
|
||||
val labels: Array[Float] = dmat.getLabel
|
||||
// check length
|
||||
assert(dmat.rowNum === labels.length)
|
||||
// set weights
|
||||
val weights: Array[Float] = Arrays.copyOf(labels, labels.length)
|
||||
dmat.setWeight(weights)
|
||||
val dweights: Array[Float] = dmat.getWeight
|
||||
assert(weights === dweights)
|
||||
}
|
||||
|
||||
test("create DMatrix from CSR") {
|
||||
// create Matrix from csr format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2 3 0
|
||||
* 4 0 2 3 5
|
||||
* 3 1 2 5 0
|
||||
*/
|
||||
val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray
|
||||
val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||
val rowHeaders = List[Long](0, 3, 7, 11).toArray
|
||||
val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR)
|
||||
assert(dmat1.rowNum === 3)
|
||||
val label1 = List[Float](1, 0, 1).toArray
|
||||
dmat1.setLabel(label1)
|
||||
val label2 = dmat1.getLabel
|
||||
assert(label2 === label1)
|
||||
}
|
||||
|
||||
test("create DMatrix from DenseMatrix") {
|
||||
val nrow = 10
|
||||
val ncol = 5
|
||||
val data0 = new Array[Float](nrow * ncol)
|
||||
// put random nums
|
||||
for (i <- data0.indices) {
|
||||
data0(i) = Random.nextFloat()
|
||||
}
|
||||
// create label
|
||||
val label0 = new Array[Float](nrow)
|
||||
for (i <- label0.indices) {
|
||||
label0(i) = Random.nextFloat()
|
||||
}
|
||||
val dmat0 = new DMatrix(data0, nrow, ncol)
|
||||
dmat0.setLabel(label0)
|
||||
// check
|
||||
assert(dmat0.rowNum === 10)
|
||||
assert(dmat0.getLabel.length === 10)
|
||||
// set weights for each instance
|
||||
val weights = new Array[Float](nrow)
|
||||
for (i <- weights.indices) {
|
||||
weights(i) = Random.nextFloat()
|
||||
}
|
||||
dmat0.setWeight(weights)
|
||||
assert(weights === dmat0.getWeight)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,90 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import ml.dmlc.xgboost4j.XGBoostError
|
||||
import org.apache.commons.logging.LogFactory
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
class ScalaBoosterImplSuite extends FunSuite {
|
||||
|
||||
private class EvalError extends EvalTrait {
|
||||
|
||||
val logger = LogFactory.getLog(classOf[EvalError])
|
||||
|
||||
private[xgboost4j] var evalMetric: String = "custom_error"
|
||||
|
||||
/**
|
||||
* get evaluate metric
|
||||
*
|
||||
* @return evalMetric
|
||||
*/
|
||||
override def getMetric: String = evalMetric
|
||||
|
||||
/**
|
||||
* evaluate with predicts and data
|
||||
*
|
||||
* @param predicts predictions as array
|
||||
* @param dmat data matrix to evaluate
|
||||
* @return result of the metric
|
||||
*/
|
||||
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||
var error: Float = 0f
|
||||
var labels: Array[Float] = null
|
||||
try {
|
||||
labels = dmat.getLabel
|
||||
} catch {
|
||||
case ex: XGBoostError =>
|
||||
logger.error(ex)
|
||||
return -1f
|
||||
}
|
||||
val nrow: Int = predicts.length
|
||||
for (i <- 0 until nrow) {
|
||||
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
|
||||
error += 1
|
||||
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
|
||||
error += 1
|
||||
}
|
||||
}
|
||||
error / labels.length
|
||||
}
|
||||
}
|
||||
|
||||
test("basic operation of booster") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||
|
||||
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||
"objective" -> "binary:logistic").toMap
|
||||
val watches = List("train" -> trainMat, "test" -> testMat).toMap
|
||||
|
||||
val round = 2
|
||||
val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
|
||||
val predicts = booster.predict(testMat, true)
|
||||
val eval = new EvalError
|
||||
assert(eval.eval(predicts, testMat) < 0.1)
|
||||
}
|
||||
|
||||
test("cross validation") {
|
||||
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||
val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6",
|
||||
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||
val round = 2
|
||||
val nfold = 5
|
||||
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user