[jvm-packages][xgboost4j-spark][Minor] Move sparkContext dependency from the XGBoostModel (#1335)

* Move sparkContext dependency from the XGBoostModel

* Update Spark example to declare SparkContext as implict
This commit is contained in:
Rahul 2016-07-08 16:13:33 +05:30 committed by Nan Zhu
parent 3f32b3f0eb
commit f14c160f4f
2 changed files with 3 additions and 3 deletions

View File

@ -31,7 +31,7 @@ object DistTrainWithSpark {
val sparkConf = new SparkConf().setAppName("XGBoost-spark-example") val sparkConf = new SparkConf().setAppName("XGBoost-spark-example")
.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
sparkConf.registerKryoClasses(Array(classOf[Booster])) sparkConf.registerKryoClasses(Array(classOf[Booster]))
val sc = new SparkContext(sparkConf) implicit val sc = new SparkContext(sparkConf)
val inputTrainPath = args(2) val inputTrainPath = args(2)
val inputTestPath = args(3) val inputTestPath = args(3)
val outputModelPath = args(4) val outputModelPath = args(4)

View File

@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix} import ml.dmlc.xgboost4j.java.{DMatrix => JDMatrix}
import ml.dmlc.xgboost4j.scala.{DMatrix, Booster} import ml.dmlc.xgboost4j.scala.{DMatrix, Booster}
class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Serializable { class XGBoostModel(_booster: Booster) extends Serializable {
/** /**
* Predict result with the given testset (represented as RDD) * Predict result with the given testset (represented as RDD)
@ -89,7 +89,7 @@ class XGBoostModel(_booster: Booster)(implicit val sc: SparkContext) extends Ser
* *
* @param modelPath The model path as in Hadoop path. * @param modelPath The model path as in Hadoop path.
*/ */
def saveModelAsHadoopFile(modelPath: String): Unit = { def saveModelAsHadoopFile(modelPath: String)(implicit sc: SparkContext): Unit = {
val path = new Path(modelPath) val path = new Path(modelPath)
val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path) val outputStream = path.getFileSystem(sc.hadoopConfiguration).create(path)
_booster.saveModel(outputStream) _booster.saveModel(outputStream)