sketch of xgboost-spark
chooseBestBooster shall be in Boosters; remove tracker.py; rename XGBoost; remove cross-validation
This commit is contained in:
24
jvm-packages/xgboost4jspark/pom.xml
Normal file
24
jvm-packages/xgboost4jspark/pom.xml
Normal file
@@ -0,0 +1,24 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven module descriptor for the xgboost4jspark artifact. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>

  <!-- Inherits from the xgboostjvm parent project. -->
  <parent>
    <groupId>ml.dmlc</groupId>
    <artifactId>xgboostjvm</artifactId>
    <version>0.1</version>
  </parent>

  <artifactId>xgboost4jspark</artifactId>

  <dependencies>
    <!-- The xgboost4j artifact from the same group. -->
    <dependency>
      <groupId>ml.dmlc</groupId>
      <artifactId>xgboost4j</artifactId>
      <version>0.1</version>
    </dependency>
    <!-- Spark core, Scala 2.10 build. -->
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.10</artifactId>
      <version>1.6.0</version>
    </dependency>
  </dependencies>
</project>
|
||||
@@ -0,0 +1,32 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala
|
||||
|
||||
import java.io.DataInputStream
|
||||
|
||||
/**
 * Builds [[DMatrix]] instances from binary input.
 *
 * Both overloads are currently stubs: they ignore their argument and return a
 * freshly constructed 1x1 matrix backed by a single-element Float array, so
 * that the rest of the sketch compiles.
 */
private[xgboost4j] object DMatrixBuilder extends Serializable {

  // TODO: stand-in result shared by both overloads until real parsing exists.
  // Declared as a `def` so every call constructs a new DMatrix, as before.
  private def placeholderMatrix: DMatrix = new DMatrix(new Array[Float](1), 1, 1)

  /**
   * Stub: `inStream` is not read yet.
   *
   * @param inStream stream that should eventually supply the matrix bytes
   * @return a placeholder 1x1 DMatrix
   */
  def buildDMatrixfromBinaryData(inStream: DataInputStream): DMatrix = placeholderMatrix

  /**
   * Stub: `binaryArray` is not read yet.
   *
   * @param binaryArray bytes that should eventually be decoded into a matrix
   * @return a placeholder 1x1 DMatrix
   */
  def buildDMatrixfromBinaryData(binaryArray: Array[Byte]): DMatrix = placeholderMatrix
}
|
||||
@@ -0,0 +1,47 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import ml.dmlc.xgboost4j.scala.Booster
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
/**
 * Wrapper around an RDD of [[Booster]]s that exposes booster-collection
 * operations. Every operation is still a placeholder in this sketch.
 *
 * @param boosters the wrapped RDD of boosters
 */
class Boosters(boosters: RDD[Booster]) {

  /**
   * Placeholder: currently does nothing.
   * NOTE(review): presumably meant to write the wrapped boosters to `path` -- confirm.
   *
   * @param path destination path (unused for now)
   */
  def save(path: String): Unit = ()

  /**
   * Placeholder: selection is not implemented and `null` is always returned.
   * Note that the `boosters` parameter shadows the constructor parameter of
   * the same name.
   *
   * @param boosters candidate boosters (unused for now)
   * @return always `null` until implemented
   */
  def chooseBestBooster(boosters: RDD[Booster]): Booster = null // TODO: implement
}
|
||||
|
||||
/**
 * Companion of [[Boosters]]: provides the implicit RDD-to-Boosters conversion
 * and a (not yet implemented) loader.
 */
object Boosters {

  // Explicitly enable the language feature used by the implicit conversion
  // below, so compiling with -feature (or fatal warnings) stays clean.
  import scala.language.implicitConversions

  /**
   * Wraps an RDD of boosters so the methods of [[Boosters]] can be called on it.
   *
   * @param boosterRDD the RDD to wrap
   * @return a new [[Boosters]] around `boosterRDD`
   */
  implicit def boosterRDDToBoosters(boosterRDD: RDD[Booster]): Boosters = {
    new Boosters(boosterRDD)
  }

  /**
   * Load boosters from `path`.
   *
   * Placeholder: loading is not implemented and `null` is always returned.
   *
   * @param path location to load from (unused for now)
   * @return always `null` until implemented
   */
  def apply(path: String): RDD[Booster] = {
    // TODO
    null
  }
}
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.scala.spark
|
||||
|
||||
import scala.collection.immutable.HashMap
|
||||
import scala.collection.mutable.ListBuffer
|
||||
|
||||
import com.typesafe.config.Config
|
||||
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, DMatrixBuilder, Booster, ObjectiveTrait, EvalTrait}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
/**
 * Spark entry point for XGBoost training (early sketch).
 *
 * `train` reads binary files through a SparkContext and trains one
 * [[Booster]] per input file inside `mapPartitions`.
 */
object XGBoost {

  // Cached SparkContext. Nothing assigns it yet: construction is still a TODO
  // in buildSparkContext.
  private var _sc: Option[SparkContext] = None

  /**
   * Returns the cached SparkContext.
   *
   * @param config user configuration (not consulted yet)
   * @return the cached SparkContext
   * @throws IllegalStateException if no SparkContext has been created, which is
   *                               always the case until the TODO below is done
   */
  private def buildSparkContext(config: Config): SparkContext = {
    if (_sc.isEmpty) {
      // TODO: build SparkContext with the user configuration (cores per task, and cores per
      // executor or total cores)
      // NOTE: currently Spark has limited support of configuration of core number in executors
    }
    // Fail with an explicit message instead of the opaque `None.get`
    // NoSuchElementException that Option.get would raise here.
    _sc.getOrElse(
      throw new IllegalStateException(
        "SparkContext is not available: building it from the configuration is not implemented yet"))
  }

  /**
   * Trains one booster per input file.
   *
   * Reads "inputPath", "numWorkers" and "round" from `config`, opens the input
   * with `binaryFiles(inputPath, numWorkers)` and, within each partition,
   * builds a DMatrix from each file's bytes and trains a booster on it.
   *
   * @param config configuration providing "inputPath", "numWorkers" and "round"
   * @param obj    objective passed through to the underlying train call (may be null)
   * @param eval   evaluation passed through to the underlying train call (may be null)
   * @return an RDD holding the boosters trained in each partition
   */
  def train(config: Config, obj: ObjectiveTrait = null, eval: EvalTrait = null): RDD[Booster] = {
    val sc = buildSparkContext(config)
    val filePath = config.getString("inputPath") // configuration entry name to be fixed
    val numWorkers = config.getInt("numWorkers")
    val round = config.getInt("round")
    // TODO: build configuration map from config
    val xgBoostConfigMap = new HashMap[String, AnyRef]()
    sc.binaryFiles(filePath, numWorkers).mapPartitions {
      trainingFiles =>
        val boosters = new ListBuffer[Booster]
        // we assume one file per DMatrix
        for ((_, fileInStream) <- trainingFiles) {
          // TODO:
          // step1: build DMatrix from fileInStream.toArray (which returns a Array[Byte]) or
          // from a fileInStream.open() (which returns a DataInputStream)
          val dMatrix = DMatrixBuilder.buildDMatrixfromBinaryData(fileInStream.toArray())
          // step2: build a Booster
          // TODO: how to build watches list???
          boosters += SXGBoost.train(xgBoostConfigMap, dMatrix, round, watches = null, obj, eval)
        }
        // TODO
        boosters.iterator
    }
  }

}
|
||||
Reference in New Issue
Block a user