From 5e309f1ce8c936920de1ffcab4c9f698f30ce32c Mon Sep 17 00:00:00 2001 From: CodingCat Date: Wed, 2 Mar 2016 15:24:13 -0500 Subject: [PATCH] add test cases for Scala API --- jvm-packages/pom.xml | 23 ++++- jvm-packages/scalastyle-config.xml | 15 ---- jvm-packages/xgboost4j/pom.xml | 13 +++ .../ml/dmlc/xgboost4j/scala/DMatrix.scala | 2 +- .../ml/dmlc/xgboost4j/scala/EvalTrait.scala | 6 +- .../dmlc/xgboost4j/scala/ObjectiveTrait.scala | 11 ++- .../ml/dmlc/xgboost4j/BoosterImplTest.java | 8 +- .../java/ml/dmlc/xgboost4j/DMatrixTest.java | 1 - .../dmlc/xgboost4j/scala/DMatrixSuite.scala | 85 ++++++++++++++++++ .../scala/ScalaBoosterImplSuite.scala | 90 +++++++++++++++++++ 10 files changed, 225 insertions(+), 29 deletions(-) create mode 100644 jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala create mode 100644 jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 5ec221175..beb2f96f2 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -95,13 +95,25 @@ org.apache.maven.plugins maven-surefire-plugin 2.19.1 - - -Djava.library.path=lib/ - + + org.scala-lang + scala-compiler + ${scala.version} + + + org.scala-lang + scala-reflect + ${scala.version} + + + org.scala-lang + scala-library + ${scala.version} + commons-logging commons-logging @@ -113,5 +125,10 @@ 2.2.6 test + + com.typesafe + config + 1.3.0 + diff --git a/jvm-packages/scalastyle-config.xml b/jvm-packages/scalastyle-config.xml index 27bb4fa8a..204b72a20 100644 --- a/jvm-packages/scalastyle-config.xml +++ b/jvm-packages/scalastyle-config.xml @@ -134,21 +134,6 @@ This file is divided into 3 sections: - - - ^FunSuite[A-Za-z]*$ - Tests must extend org.apache.spark.SparkFunSuite instead. - - - - - ^println$ - - - @VisibleForTesting true + + org.scalatest + scalatest-maven-plugin + 1.0 + + + test + + test + + + + diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala index 73fafc7f0..634aef190 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala @@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError} -class DMatrix private(private[scala] val jDMatrix: JDMatrix) { +class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) { /** * init DMatrix from file (svmlight format) diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala index 461f515a1..5f4e85683 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala @@ -16,7 +16,7 @@ package ml.dmlc.xgboost4j.scala -import ml.dmlc.xgboost4j.IEvaluation +import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IEvaluation} trait EvalTrait extends IEvaluation { @@ -35,4 +35,8 @@ trait EvalTrait extends IEvaluation { * @return result of the metric */ def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float + + private[scala] def eval(predicts: Array[Array[Float]], jdmat: JDMatrix): Float = { + eval(predicts, new DMatrix(jdmat)) + } } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala index c5df8aead..8f7bb86f0 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala @@ -16,7 +16,9 @@ package ml.dmlc.xgboost4j.scala -import ml.dmlc.xgboost4j.IObjective +import scala.collection.JavaConverters._ + +import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IObjective} trait ObjectiveTrait extends IObjective { /** @@ -26,5 +28,10 @@ trait ObjectiveTrait extends IObjective { * @param dtrain training data * @return List with two float array, correspond to first order grad and second order grad */ - def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]] + def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): List[Array[Float]] + + private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix): + java.util.List[Array[Float]] = { + getGradient(predicts, new DMatrix(dtrain)).asJava + } } diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java index e44bc95bc..59dca16f5 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java @@ -80,7 +80,7 @@ public class BoosterImplTest { }; //set watchList - HashMap watches = new HashMap<>(); + HashMap watches = new HashMap(); watches.put("train", trainMat); watches.put("test", testMat); @@ -129,10 +129,6 @@ public class BoosterImplTest { //do 5-fold cross validation int round = 2; int nfold = 5; - //set additional eval_metrics - String[] metrics = null; - - String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, metrics, - null, null); + String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null); } } diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java index 9b3a8b860..72062a9c4 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java @@ -57,7 +57,6 @@ public class DMatrixTest { long[] rowHeaders = new long[]{0, 3, 7, 11}; DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR); //check row num - System.out.println(dmat1.rowNum()); TestCase.assertTrue(dmat1.rowNum() == 3); //test set label float[] label1 = new float[]{1, 0, 1}; diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala new file mode 100644 index 000000000..64ec3e033 --- /dev/null +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala @@ -0,0 +1,85 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala + +import java.util.Arrays + +import scala.util.Random + +import ml.dmlc.xgboost4j.{DMatrix => JDMatrix} +import org.scalatest.FunSuite + +class DMatrixSuite extends FunSuite { + test("create DMatrix from File") { + val dmat = new DMatrix("../../demo/data/agaricus.txt.test") + // get label + val labels: Array[Float] = dmat.getLabel + // check length + assert(dmat.rowNum === labels.length) + // set weights + val weights: Array[Float] = Arrays.copyOf(labels, labels.length) + dmat.setWeight(weights) + val dweights: Array[Float] = dmat.getWeight + assert(weights === dweights) + } + + test("create DMatrix from CSR") { + // create Matrix from csr format sparse Matrix and labels + /** + * sparse matrix + * 1 0 2 3 0 + * 4 0 2 3 5 + * 3 1 2 5 0 + */ + val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray + val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray + val rowHeaders = List[Long](0, 3, 7, 11).toArray + val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR) + assert(dmat1.rowNum === 3) + val label1 = List[Float](1, 0, 1).toArray + dmat1.setLabel(label1) + val label2 = dmat1.getLabel + assert(label2 === label1) + } + + test("create DMatrix from DenseMatrix") { + val nrow = 10 + val ncol = 5 + val data0 = new Array[Float](nrow * ncol) + // put random nums + for (i <- data0.indices) { + data0(i) = Random.nextFloat() + } + // create label + val label0 = new Array[Float](nrow) + for (i <- label0.indices) { + label0(i) = Random.nextFloat() + } + val dmat0 = new DMatrix(data0, nrow, ncol) + dmat0.setLabel(label0) + // check + assert(dmat0.rowNum === 10) + assert(dmat0.getLabel.length === 10) + // set weights for each instance + val weights = new Array[Float](nrow) + for (i <- weights.indices) { + weights(i) = Random.nextFloat() + } + dmat0.setWeight(weights) + assert(weights === dmat0.getWeight) + } +} diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala new file mode 100644 index 000000000..e911ec985 --- /dev/null +++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala @@ -0,0 +1,90 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala + +import ml.dmlc.xgboost4j.XGBoostError +import org.apache.commons.logging.LogFactory +import org.scalatest.FunSuite + +class ScalaBoosterImplSuite extends FunSuite { + + private class EvalError extends EvalTrait { + + val logger = LogFactory.getLog(classOf[EvalError]) + + private[xgboost4j] var evalMetric: String = "custom_error" + + /** + * get evaluate metric + * + * @return evalMetric + */ + override def getMetric: String = evalMetric + + /** + * evaluate with predicts and data + * + * @param predicts predictions as array + * @param dmat data matrix to evaluate + * @return result of the metric + */ + override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = { + var error: Float = 0f + var labels: Array[Float] = null + try { + labels = dmat.getLabel + } catch { + case ex: XGBoostError => + logger.error(ex) + return -1f + } + val nrow: Int = predicts.length + for (i <- 0 until nrow) { + if (labels(i) == 0.0 && predicts(i)(0) > 0) { + error += 1 + } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) { + error += 1 + } + } + error / labels.length + } + } + + test("basic operation of booster") { + val trainMat = new DMatrix("../../demo/data/agaricus.txt.train") + val testMat = new DMatrix("../../demo/data/agaricus.txt.test") + + val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", + "objective" -> "binary:logistic").toMap + val watches = List("train" -> trainMat, "test" -> testMat).toMap + + val round = 2 + val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null) + val predicts = booster.predict(testMat, true) + val eval = new EvalError + assert(eval.eval(predicts, testMat) < 0.1) + } + + test("cross validation") { + val trainMat = new DMatrix("../../demo/data/agaricus.txt.train") + val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6", + "objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap + val round = 2 + val nfold = 5 + XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null) + } +}