diff --git a/jvm-packages/checkstyle-suppressions.xml b/jvm-packages/checkstyle-suppressions.xml index 28aed4fab..4a0f9ab33 100644 --- a/jvm-packages/checkstyle-suppressions.xml +++ b/jvm-packages/checkstyle-suppressions.xml @@ -29,5 +29,5 @@ +files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java"/> diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala index fd336a9c2..4ad951567 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala @@ -21,7 +21,7 @@ import java.util.{Iterator => JIterator} import scala.collection.mutable.ListBuffer import scala.collection.JavaConverters._ -import ml.dmlc.xgboost4j.DataBatch +import ml.dmlc.xgboost4j.java.DataBatch import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector} import org.apache.spark.mllib.regression.LabeledPoint diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 96a6210a7..8151e6ccc 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -20,7 +20,7 @@ import scala.collection.immutable.HashMap import scala.collection.JavaConverters._ import com.typesafe.config.Config -import ml.dmlc.xgboost4j.{DMatrix => JDMatrix} +import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import org.apache.spark.SparkContext import org.apache.spark.mllib.regression.LabeledPoint @@ -48,7 +48,9 @@ object XGBoost { val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples) val dMatrix = new DMatrix(new JDMatrix(dataBatches, null)) Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval)) - } + }.cache() + // force the job + sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters) // TODO: how to choose best model boosters.first() } diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index 849ad6168..47efb053f 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark import scala.collection.JavaConverters._ -import ml.dmlc.xgboost4j.{DMatrix => JDMatrix} +import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} import ml.dmlc.xgboost4j.scala.{DMatrix, Booster} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XgboostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java similarity index 100% rename from jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XgboostJNI.java rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java