merge
This commit is contained in:
parent
e8560c7909
commit
bb43177eb1
@ -29,5 +29,5 @@
|
|||||||
|
|
||||||
<suppressions>
|
<suppressions>
|
||||||
<suppress checks=".*"
|
<suppress checks=".*"
|
||||||
files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java"/>
|
files="xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java"/>
|
||||||
</suppressions>
|
</suppressions>
|
||||||
|
|||||||
@ -21,7 +21,7 @@ import java.util.{Iterator => JIterator}
|
|||||||
import scala.collection.mutable.ListBuffer
|
import scala.collection.mutable.ListBuffer
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.DataBatch
|
import ml.dmlc.xgboost4j.java.DataBatch
|
||||||
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
|
import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,7 @@ import scala.collection.immutable.HashMap
|
|||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import com.typesafe.config.Config
|
import com.typesafe.config.Config
|
||||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
|
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
@ -48,7 +48,9 @@ object XGBoost {
|
|||||||
val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
|
val dataBatches = dataUtilsBroadcast.value.fromLabeledPointsToSparseMatrix(trainingSamples)
|
||||||
val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
|
val dMatrix = new DMatrix(new JDMatrix(dataBatches, null))
|
||||||
Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
|
Iterator(SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval))
|
||||||
}
|
}.cache()
|
||||||
|
// force the job
|
||||||
|
sc.runJob(boosters, (boosters: Iterator[Booster]) => boosters)
|
||||||
// TODO: how to choose best model
|
// TODO: how to choose best model
|
||||||
boosters.first()
|
boosters.first()
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark
|
|||||||
|
|
||||||
import scala.collection.JavaConverters._
|
import scala.collection.JavaConverters._
|
||||||
|
|
||||||
import ml.dmlc.xgboost4j.{DMatrix => JDMatrix}
|
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
|
||||||
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
|
||||||
import org.apache.spark.mllib.regression.LabeledPoint
|
import org.apache.spark.mllib.regression.LabeledPoint
|
||||||
import org.apache.spark.rdd.RDD
|
import org.apache.spark.rdd.RDD
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user