add test cases for Scala API
This commit is contained in:
parent
f8fff6c6fc
commit
5e309f1ce8
@ -95,13 +95,25 @@
|
|||||||
<groupId>org.apache.maven.plugins</groupId>
|
<groupId>org.apache.maven.plugins</groupId>
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
<artifactId>maven-surefire-plugin</artifactId>
|
||||||
<version>2.19.1</version>
|
<version>2.19.1</version>
|
||||||
<configuration>
|
|
||||||
<argLine>-Djava.library.path=lib/</argLine>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
</build>
|
</build>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.scala-lang</groupId>
|
||||||
|
<artifactId>scala-compiler</artifactId>
|
||||||
|
<version>${scala.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.scala-lang</groupId>
|
||||||
|
<artifactId>scala-reflect</artifactId>
|
||||||
|
<version>${scala.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.scala-lang</groupId>
|
||||||
|
<artifactId>scala-library</artifactId>
|
||||||
|
<version>${scala.version}</version>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>commons-logging</groupId>
|
<groupId>commons-logging</groupId>
|
||||||
<artifactId>commons-logging</artifactId>
|
<artifactId>commons-logging</artifactId>
|
||||||
@ -113,5 +125,10 @@
|
|||||||
<version>2.2.6</version>
|
<version>2.2.6</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.typesafe</groupId>
|
||||||
|
<artifactId>config</artifactId>
|
||||||
|
<version>1.3.0</version>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@ -134,21 +134,6 @@ This file is divided into 3 sections:
|
|||||||
<!-- ??? usually shouldn't be checked into the code base. -->
|
<!-- ??? usually shouldn't be checked into the code base. -->
|
||||||
<check level="error" class="org.scalastyle.scalariform.NotImplementedErrorUsage" enabled="true"></check>
|
<check level="error" class="org.scalastyle.scalariform.NotImplementedErrorUsage" enabled="true"></check>
|
||||||
|
|
||||||
<!-- As of SPARK-7558, all tests in Spark should extend o.a.s.SparkFunSuite instead of FunSuite directly -->
|
|
||||||
<check customId="funsuite" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
|
||||||
<parameters><parameter name="regex">^FunSuite[A-Za-z]*$</parameter></parameters>
|
|
||||||
<customMessage>Tests must extend org.apache.spark.SparkFunSuite instead.</customMessage>
|
|
||||||
</check>
|
|
||||||
|
|
||||||
<!-- As of SPARK-7977 all printlns need to be wrapped in '// scalastyle:off/on println' -->
|
|
||||||
<check customId="println" level="error" class="org.scalastyle.scalariform.TokenChecker" enabled="true">
|
|
||||||
<parameters><parameter name="regex">^println$</parameter></parameters>
|
|
||||||
<customMessage><![CDATA[Are you sure you want to println? If yes, wrap the code block with
|
|
||||||
// scalastyle:off println
|
|
||||||
println(...)
|
|
||||||
// scalastyle:on println]]></customMessage>
|
|
||||||
</check>
|
|
||||||
|
|
||||||
<check customId="visiblefortesting" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
<check customId="visiblefortesting" level="error" class="org.scalastyle.file.RegexChecker" enabled="true">
|
||||||
<parameters><parameter name="regex">@VisibleForTesting</parameter></parameters>
|
<parameters><parameter name="regex">@VisibleForTesting</parameter></parameters>
|
||||||
<customMessage><![CDATA[
|
<customMessage><![CDATA[
|
||||||
|
|||||||
@ -22,6 +22,19 @@
|
|||||||
<nohelp>true</nohelp>
|
<nohelp>true</nohelp>
|
||||||
</configuration>
|
</configuration>
|
||||||
</plugin>
|
</plugin>
|
||||||
|
<plugin>
|
||||||
|
<groupId>org.scalatest</groupId>
|
||||||
|
<artifactId>scalatest-maven-plugin</artifactId>
|
||||||
|
<version>1.0</version>
|
||||||
|
<executions>
|
||||||
|
<execution>
|
||||||
|
<id>test</id>
|
||||||
|
<goals>
|
||||||
|
<goal>test</goal>
|
||||||
|
</goals>
|
||||||
|
</execution>
|
||||||
|
</executions>
|
||||||
|
</plugin>
|
||||||
</plugins>
|
</plugins>
|
||||||
</build>
|
</build>
|
||||||
<dependencies>
|
<dependencies>
|
||||||
|
|||||||
@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
|
|||||||
|
|
||||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
|
||||||
|
|
||||||
class DMatrix private(private[scala] val jDMatrix: JDMatrix) {
|
class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* init DMatrix from file (svmlight format)
|
* init DMatrix from file (svmlight format)
|
||||||
|
|||||||
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.IEvaluation
|
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IEvaluation}
|
||||||
|
|
||||||
trait EvalTrait extends IEvaluation {
|
trait EvalTrait extends IEvaluation {
|
||||||
|
|
||||||
@ -35,4 +35,8 @@ trait EvalTrait extends IEvaluation {
|
|||||||
* @return result of the metric
|
* @return result of the metric
|
||||||
*/
|
*/
|
||||||
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
|
||||||
|
|
||||||
|
private[scala] def eval(predicts: Array[Array[Float]], jdmat: JDMatrix): Float = {
|
||||||
|
eval(predicts, new DMatrix(jdmat))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,7 +16,9 @@
|
|||||||
|
|
||||||
package ml.dmlc.xgboost4j.scala
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.IObjective
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IObjective}
|
||||||
|
|
||||||
trait ObjectiveTrait extends IObjective {
|
trait ObjectiveTrait extends IObjective {
|
||||||
/**
|
/**
|
||||||
@ -26,5 +28,10 @@ trait ObjectiveTrait extends IObjective {
|
|||||||
* @param dtrain training data
|
* @param dtrain training data
|
||||||
* @return List with two float array, correspond to first order grad and second order grad
|
* @return List with two float array, correspond to first order grad and second order grad
|
||||||
*/
|
*/
|
||||||
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
|
def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): List[Array[Float]]
|
||||||
|
|
||||||
|
private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
|
||||||
|
java.util.List[Array[Float]] = {
|
||||||
|
getGradient(predicts, new DMatrix(dtrain)).asJava
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -80,7 +80,7 @@ public class BoosterImplTest {
|
|||||||
};
|
};
|
||||||
|
|
||||||
//set watchList
|
//set watchList
|
||||||
HashMap<String, DMatrix> watches = new HashMap<>();
|
HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
|
||||||
|
|
||||||
watches.put("train", trainMat);
|
watches.put("train", trainMat);
|
||||||
watches.put("test", testMat);
|
watches.put("test", testMat);
|
||||||
@ -129,10 +129,6 @@ public class BoosterImplTest {
|
|||||||
//do 5-fold cross validation
|
//do 5-fold cross validation
|
||||||
int round = 2;
|
int round = 2;
|
||||||
int nfold = 5;
|
int nfold = 5;
|
||||||
//set additional eval_metrics
|
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
|
||||||
String[] metrics = null;
|
|
||||||
|
|
||||||
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, metrics,
|
|
||||||
null, null);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -57,7 +57,6 @@ public class DMatrixTest {
|
|||||||
long[] rowHeaders = new long[]{0, 3, 7, 11};
|
long[] rowHeaders = new long[]{0, 3, 7, 11};
|
||||||
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
|
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
|
||||||
//check row num
|
//check row num
|
||||||
System.out.println(dmat1.rowNum());
|
|
||||||
TestCase.assertTrue(dmat1.rowNum() == 3);
|
TestCase.assertTrue(dmat1.rowNum() == 3);
|
||||||
//test set label
|
//test set label
|
||||||
float[] label1 = new float[]{1, 0, 1};
|
float[] label1 = new float[]{1, 0, 1};
|
||||||
|
|||||||
@ -0,0 +1,85 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
|
import java.util.Arrays
|
||||||
|
|
||||||
|
import scala.util.Random
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
|
||||||
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
|
class DMatrixSuite extends FunSuite {
|
||||||
|
test("create DMatrix from File") {
|
||||||
|
val dmat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
// get label
|
||||||
|
val labels: Array[Float] = dmat.getLabel
|
||||||
|
// check length
|
||||||
|
assert(dmat.rowNum === labels.length)
|
||||||
|
// set weights
|
||||||
|
val weights: Array[Float] = Arrays.copyOf(labels, labels.length)
|
||||||
|
dmat.setWeight(weights)
|
||||||
|
val dweights: Array[Float] = dmat.getWeight
|
||||||
|
assert(weights === dweights)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("create DMatrix from CSR") {
|
||||||
|
// create Matrix from csr format sparse Matrix and labels
|
||||||
|
/**
|
||||||
|
* sparse matrix
|
||||||
|
* 1 0 2 3 0
|
||||||
|
* 4 0 2 3 5
|
||||||
|
* 3 1 2 5 0
|
||||||
|
*/
|
||||||
|
val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray
|
||||||
|
val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray
|
||||||
|
val rowHeaders = List[Long](0, 3, 7, 11).toArray
|
||||||
|
val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR)
|
||||||
|
assert(dmat1.rowNum === 3)
|
||||||
|
val label1 = List[Float](1, 0, 1).toArray
|
||||||
|
dmat1.setLabel(label1)
|
||||||
|
val label2 = dmat1.getLabel
|
||||||
|
assert(label2 === label1)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("create DMatrix from DenseMatrix") {
|
||||||
|
val nrow = 10
|
||||||
|
val ncol = 5
|
||||||
|
val data0 = new Array[Float](nrow * ncol)
|
||||||
|
// put random nums
|
||||||
|
for (i <- data0.indices) {
|
||||||
|
data0(i) = Random.nextFloat()
|
||||||
|
}
|
||||||
|
// create label
|
||||||
|
val label0 = new Array[Float](nrow)
|
||||||
|
for (i <- label0.indices) {
|
||||||
|
label0(i) = Random.nextFloat()
|
||||||
|
}
|
||||||
|
val dmat0 = new DMatrix(data0, nrow, ncol)
|
||||||
|
dmat0.setLabel(label0)
|
||||||
|
// check
|
||||||
|
assert(dmat0.rowNum === 10)
|
||||||
|
assert(dmat0.getLabel.length === 10)
|
||||||
|
// set weights for each instance
|
||||||
|
val weights = new Array[Float](nrow)
|
||||||
|
for (i <- weights.indices) {
|
||||||
|
weights(i) = Random.nextFloat()
|
||||||
|
}
|
||||||
|
dmat0.setWeight(weights)
|
||||||
|
assert(weights === dmat0.getWeight)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,90 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package ml.dmlc.xgboost4j.scala
|
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.XGBoostError
|
||||||
|
import org.apache.commons.logging.LogFactory
|
||||||
|
import org.scalatest.FunSuite
|
||||||
|
|
||||||
|
class ScalaBoosterImplSuite extends FunSuite {
|
||||||
|
|
||||||
|
private class EvalError extends EvalTrait {
|
||||||
|
|
||||||
|
val logger = LogFactory.getLog(classOf[EvalError])
|
||||||
|
|
||||||
|
private[xgboost4j] var evalMetric: String = "custom_error"
|
||||||
|
|
||||||
|
/**
|
||||||
|
* get evaluate metric
|
||||||
|
*
|
||||||
|
* @return evalMetric
|
||||||
|
*/
|
||||||
|
override def getMetric: String = evalMetric
|
||||||
|
|
||||||
|
/**
|
||||||
|
* evaluate with predicts and data
|
||||||
|
*
|
||||||
|
* @param predicts predictions as array
|
||||||
|
* @param dmat data matrix to evaluate
|
||||||
|
* @return result of the metric
|
||||||
|
*/
|
||||||
|
override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
|
||||||
|
var error: Float = 0f
|
||||||
|
var labels: Array[Float] = null
|
||||||
|
try {
|
||||||
|
labels = dmat.getLabel
|
||||||
|
} catch {
|
||||||
|
case ex: XGBoostError =>
|
||||||
|
logger.error(ex)
|
||||||
|
return -1f
|
||||||
|
}
|
||||||
|
val nrow: Int = predicts.length
|
||||||
|
for (i <- 0 until nrow) {
|
||||||
|
if (labels(i) == 0.0 && predicts(i)(0) > 0) {
|
||||||
|
error += 1
|
||||||
|
} else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
|
||||||
|
error += 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
error / labels.length
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("basic operation of booster") {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
|
||||||
|
|
||||||
|
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
|
||||||
|
"objective" -> "binary:logistic").toMap
|
||||||
|
val watches = List("train" -> trainMat, "test" -> testMat).toMap
|
||||||
|
|
||||||
|
val round = 2
|
||||||
|
val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
|
||||||
|
val predicts = booster.predict(testMat, true)
|
||||||
|
val eval = new EvalError
|
||||||
|
assert(eval.eval(predicts, testMat) < 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
test("cross validation") {
|
||||||
|
val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
|
||||||
|
val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6",
|
||||||
|
"objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
|
||||||
|
val round = 2
|
||||||
|
val nfold = 5
|
||||||
|
XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user