This commit is contained in:
CodingCat 2016-03-05 13:44:55 -05:00
parent e8560c7909
commit bb43177eb1
5 changed files with 7 additions and 5 deletions

View File

@ -29,5 +29,5 @@
<suppressions>
<suppress checks=".*"
files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java"/>
files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java"/>
</suppressions>

View File

@ -21,7 +21,7 @@ import java.util.{Iterator => JIterator}
import scala.collection.mutable.ListBuffer
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.DataBatch
import ml.dmlc.xgboost4j.java.DataBatch
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint

View File

@ -20,7 +20,7 @@ import scala.collection.immutable.HashMap
import scala.collection.JavaConverters._
import com.typesafe.config.Config
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
@ -48,7 +48,9 @@ object XGBoost {
val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
}
}.cache()
// force the job to run now; the RDD is cached so first() below can reuse the trained boosters
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
// TODO: decide how to choose the best model; for now the first booster is taken
boosters.first()
}

View File

@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
import scala.collection.JavaConverters._
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD