diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml
index 5ec221175..beb2f96f2 100644
--- a/jvm-packages/pom.xml
+++ b/jvm-packages/pom.xml
@@ -95,13 +95,25 @@
org.apache.maven.plugins
maven-surefire-plugin
2.19.1
-
- -Djava.library.path=lib/
-
+
+ org.scala-lang
+ scala-compiler
+ ${scala.version}
+
+
+ org.scala-lang
+ scala-reflect
+ ${scala.version}
+
+
+ org.scala-lang
+ scala-library
+ ${scala.version}
+
commons-logging
commons-logging
@@ -113,5 +125,10 @@
2.2.6
test
+
+ com.typesafe
+ config
+ 1.3.0
+
diff --git a/jvm-packages/scalastyle-config.xml b/jvm-packages/scalastyle-config.xml
index 27bb4fa8a..204b72a20 100644
--- a/jvm-packages/scalastyle-config.xml
+++ b/jvm-packages/scalastyle-config.xml
@@ -134,21 +134,6 @@ This file is divided into 3 sections:
-
-
- ^FunSuite[A-Za-z]*$
- Tests must extend org.apache.spark.SparkFunSuite instead.
-
-
-
-
- ^println$
-
-
-
@VisibleForTesting
true
+
+ org.scalatest
+ scalatest-maven-plugin
+ 1.0
+
+
+ test
+
+ test
+
+
+
+
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
index 73fafc7f0..634aef190 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala
@@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError}
-class DMatrix private(private[scala] val jDMatrix: JDMatrix) {
+class DMatrix private[scala](private[scala] val jDMatrix: JDMatrix) {
/**
* init DMatrix from file (svmlight format)
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala
index 461f515a1..5f4e85683 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala
@@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j.scala
-import ml.dmlc.xgboost4j.IEvaluation
+import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IEvaluation}
trait EvalTrait extends IEvaluation {
@@ -35,4 +35,8 @@ trait EvalTrait extends IEvaluation {
* @return result of the metric
*/
def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float
+
+ private[scala] def eval(predicts: Array[Array[Float]], jdmat: JDMatrix): Float = {
+ eval(predicts, new DMatrix(jdmat))
+ }
}
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala
index c5df8aead..8f7bb86f0 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala
@@ -16,7 +16,9 @@
package ml.dmlc.xgboost4j.scala
-import ml.dmlc.xgboost4j.IObjective
+import scala.collection.JavaConverters._
+
+import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, IObjective}
trait ObjectiveTrait extends IObjective {
/**
@@ -26,5 +28,10 @@ trait ObjectiveTrait extends IObjective {
* @param dtrain training data
* @return List with two float array, correspond to first order grad and second order grad
*/
- def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]]
+ def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): List[Array[Float]]
+
+ private[scala] def getGradient(predicts: Array[Array[Float]], dtrain: JDMatrix):
+ java.util.List[Array[Float]] = {
+ getGradient(predicts, new DMatrix(dtrain)).asJava
+ }
}
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java
index e44bc95bc..59dca16f5 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java
@@ -80,7 +80,7 @@ public class BoosterImplTest {
};
//set watchList
- HashMap watches = new HashMap<>();
+ HashMap watches = new HashMap();
watches.put("train", trainMat);
watches.put("test", testMat);
@@ -129,10 +129,6 @@ public class BoosterImplTest {
//do 5-fold cross validation
int round = 2;
int nfold = 5;
- //set additional eval_metrics
- String[] metrics = null;
-
- String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, metrics,
- null, null);
+ String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, null, null, null);
}
}
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java
index 9b3a8b860..72062a9c4 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java
@@ -57,7 +57,6 @@ public class DMatrixTest {
long[] rowHeaders = new long[]{0, 3, 7, 11};
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
//check row num
- System.out.println(dmat1.rowNum());
TestCase.assertTrue(dmat1.rowNum() == 3);
//test set label
float[] label1 = new float[]{1, 0, 1};
diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala
new file mode 100644
index 000000000..64ec3e033
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/DMatrixSuite.scala
@@ -0,0 +1,85 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import java.util.Arrays
+
+import scala.util.Random
+
+import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
+import org.scalatest.FunSuite
+
+class DMatrixSuite extends FunSuite {
+ test("create DMatrix from File") {
+ val dmat = new DMatrix("../../demo/data/agaricus.txt.test")
+ // get label
+ val labels: Array[Float] = dmat.getLabel
+ // check length
+ assert(dmat.rowNum === labels.length)
+ // set weights
+ val weights: Array[Float] = Arrays.copyOf(labels, labels.length)
+ dmat.setWeight(weights)
+ val dweights: Array[Float] = dmat.getWeight
+ assert(weights === dweights)
+ }
+
+ test("create DMatrix from CSR") {
+ // create Matrix from csr format sparse Matrix and labels
+ /**
+ * sparse matrix
+ * 1 0 2 3 0
+ * 4 0 2 3 5
+ * 3 1 2 5 0
+ */
+ val data = List[Float](1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5).toArray
+ val colIndex = List(0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3).toArray
+ val rowHeaders = List[Long](0, 3, 7, 11).toArray
+ val dmat1 = new DMatrix(rowHeaders, colIndex, data, JDMatrix.SparseType.CSR)
+ assert(dmat1.rowNum === 3)
+ val label1 = List[Float](1, 0, 1).toArray
+ dmat1.setLabel(label1)
+ val label2 = dmat1.getLabel
+ assert(label2 === label1)
+ }
+
+ test("create DMatrix from DenseMatrix") {
+ val nrow = 10
+ val ncol = 5
+ val data0 = new Array[Float](nrow * ncol)
+ // put random nums
+ for (i <- data0.indices) {
+ data0(i) = Random.nextFloat()
+ }
+ // create label
+ val label0 = new Array[Float](nrow)
+ for (i <- label0.indices) {
+ label0(i) = Random.nextFloat()
+ }
+ val dmat0 = new DMatrix(data0, nrow, ncol)
+ dmat0.setLabel(label0)
+ // check
+ assert(dmat0.rowNum === 10)
+ assert(dmat0.getLabel.length === 10)
+ // set weights for each instance
+ val weights = new Array[Float](nrow)
+ for (i <- weights.indices) {
+ weights(i) = Random.nextFloat()
+ }
+ dmat0.setWeight(weights)
+ assert(weights === dmat0.getWeight)
+ }
+}
diff --git a/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
new file mode 100644
index 000000000..e911ec985
--- /dev/null
+++ b/jvm-packages/xgboost4j/src/test/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImplSuite.scala
@@ -0,0 +1,90 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala
+
+import ml.dmlc.xgboost4j.XGBoostError
+import org.apache.commons.logging.LogFactory
+import org.scalatest.FunSuite
+
+class ScalaBoosterImplSuite extends FunSuite {
+
+ private class EvalError extends EvalTrait {
+
+ val logger = LogFactory.getLog(classOf[EvalError])
+
+ private[xgboost4j] var evalMetric: String = "custom_error"
+
+ /**
+ * get evaluate metric
+ *
+ * @return evalMetric
+ */
+ override def getMetric: String = evalMetric
+
+ /**
+ * evaluate with predicts and data
+ *
+ * @param predicts predictions as array
+ * @param dmat data matrix to evaluate
+ * @return result of the metric
+ */
+ override def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float = {
+ var error: Float = 0f
+ var labels: Array[Float] = null
+ try {
+ labels = dmat.getLabel
+ } catch {
+ case ex: XGBoostError =>
+ logger.error(ex)
+ return -1f
+ }
+ val nrow: Int = predicts.length
+ for (i <- 0 until nrow) {
+ if (labels(i) == 0.0 && predicts(i)(0) > 0) {
+ error += 1
+ } else if (labels(i) == 1.0 && predicts(i)(0) <= 0) {
+ error += 1
+ }
+ }
+ error / labels.length
+ }
+ }
+
+ test("basic operation of booster") {
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+ val testMat = new DMatrix("../../demo/data/agaricus.txt.test")
+
+ val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
+ "objective" -> "binary:logistic").toMap
+ val watches = List("train" -> trainMat, "test" -> testMat).toMap
+
+ val round = 2
+ val booster = XGBoost.train(paramMap, trainMat, round, watches, null, null)
+ val predicts = booster.predict(testMat, true)
+ val eval = new EvalError
+ assert(eval.eval(predicts, testMat) < 0.1)
+ }
+
+ test("cross validation") {
+ val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
+ val params = List("eta" -> "1.0", "max_depth" -> "3", "slient" -> "1", "nthread" -> "6",
+ "objective" -> "binary:logistic", "gamma" -> "1.0", "eval_metric" -> "error").toMap
+ val round = 2
+ val nfold = 5
+ XGBoost.crossValiation(params, trainMat, round, nfold, null, null, null)
+ }
+}